diff --git a/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/JEPA.mp4 b/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/JEPA.mp4 new file mode 100644 index 0000000000..fc76e2eb4c Binary files /dev/null and b/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/JEPA.mp4 differ diff --git a/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/JEPA_BIJEPAXLITE_3SEED_STATUS.md b/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/JEPA_BIJEPAXLITE_3SEED_STATUS.md new file mode 100644 index 0000000000..263c5c1351 --- /dev/null +++ b/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/JEPA_BIJEPAXLITE_3SEED_STATUS.md @@ -0,0 +1,63 @@ +# BIJEPAX-lite 3-seed candidate + +Config matches successful seed 42 run: +- script: our_submission/train_gpt_v15_bijepax.py +- DISABLE_COMPILE=1 +- CASEOPS_ENABLED=1 +- PPM_MIXER_ENABLED=1 order=5 H=0.999 L=0.18 T=0.80 +- TTT_ENABLED=0 +- LQER_TOP_K=1 +- BIJEPAX_ENABLED=1 weight=0.01 start=0.35 end=0.80 fwd_hops=4 bwd_hops=4 cycle=0 head_dim=32 stride=64 lr=0.001 + +Existing seed 42: +- run: v15_bijepaxlite_lqer1_nocompile_s42_20260501_022405 +- final ppm_sliding val_bpb: 0.97234287 +- artifact bytes: 15997180 +- eval time: 502131ms +- rc: 0 + +Queued seeds: 314, 999 + + +## v15_bijepaxlite_lqer1_nocompile_s314_20260501_025209 +- started: 2026-05-01T02:52:09Z +- log: /workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s314_20260501_025209/train.log +- finished: 2026-05-01T03:14:07Z +- rc: 0 +- scores: diagnostic pre-quantization post-ema val_loss:2.42899363 val_bpb:1.10988323 eval_time:9910ms;Total submission size quantized+pergroup: 15999539 bytes;diagnostic quantized val_loss:2.44155528 val_bpb:1.11562304 eval_time:9926ms;ppm_mixer val_bpb:0.97206308 eval_time:453715ms order=5 H=0.999 L=0.18 T=0.8 N_tokens=47851520 N_sidecar_bytes=151074499;ppm_sliding val_loss:2.45044876 val_bpb:0.97206308 eval_time:499038ms; + +## v15_bijepaxlite_lqer1_nocompile_s999_20260501_031407 +- started: 2026-05-01T03:14:07Z +- log: /workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s999_20260501_031407/train.log +- finished: 2026-05-01T03:36:01Z +- rc: 0 +- scores: diagnostic pre-quantization post-ema val_loss:2.43314506 val_bpb:1.11178015 eval_time:9911ms;Total submission size quantized+pergroup: 15997593 bytes;diagnostic quantized val_loss:2.44582432 val_bpb:1.11757370 eval_time:11393ms;ppm_mixer val_bpb:0.97373767 eval_time:451054ms order=5 H=0.999 L=0.18 T=0.8 N_tokens=47851520 N_sidecar_bytes=151074499;ppm_sliding val_loss:2.45502055 val_bpb:0.97373767 eval_time:496384ms; + +## Final scrape +/workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s314_20260501_025209/train.log: artifact_dir: /workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s314_20260501_025209 +/workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s314_20260501_025209/train.log: logfile: /workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s314_20260501_025209/v15_bijepaxlite_lqer1_nocompile_s314_20260501_025209.txt +/workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s314_20260501_025209/train.log: model_path: /workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s314_20260501_025209/final_model.pt +/workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s314_20260501_025209/train.log: quantized_model_path: /workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s314_20260501_025209/final_model.int6.ptz +/workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s314_20260501_025209/train.log: run_id: v15_bijepaxlite_lqer1_nocompile_s314_20260501_025209 +/workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s314_20260501_025209/train.log:Total submission size quantized+pergroup: 15999539 bytes +/workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s314_20260501_025209/train.log:diagnostic quantized val_loss:2.44155528 val_bpb:1.11562304 eval_time:9926ms +/workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s314_20260501_025209/train.log:ppm_mixer val_bpb:0.97206308 eval_time:453715ms order=5 H=0.999 L=0.18 T=0.8 N_tokens=47851520 N_sidecar_bytes=151074499 +/workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s314_20260501_025209/train.log:ppm_sliding val_loss:2.45044876 val_bpb:0.97206308 eval_time:499038ms +/workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s42_20260501_022405/train.log: artifact_dir: /workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s42_20260501_022405 +/workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s42_20260501_022405/train.log: logfile: /workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s42_20260501_022405/v15_bijepaxlite_lqer1_nocompile_s42_20260501_022405.txt +/workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s42_20260501_022405/train.log: model_path: /workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s42_20260501_022405/final_model.pt +/workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s42_20260501_022405/train.log: quantized_model_path: /workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s42_20260501_022405/final_model.int6.ptz +/workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s42_20260501_022405/train.log: run_id: v15_bijepaxlite_lqer1_nocompile_s42_20260501_022405 +/workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s42_20260501_022405/train.log:Total submission size quantized+pergroup: 15997180 bytes +/workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s42_20260501_022405/train.log:diagnostic quantized val_loss:2.44116551 val_bpb:1.11544494 eval_time:10342ms +/workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s42_20260501_022405/train.log:ppm_mixer val_bpb:0.97234287 eval_time:456845ms order=5 H=0.999 L=0.18 T=0.8 N_tokens=47851520 N_sidecar_bytes=151074499 +/workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s42_20260501_022405/train.log:ppm_sliding val_loss:2.45118426 val_bpb:0.97234287 eval_time:502131ms +/workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s999_20260501_031407/train.log: artifact_dir: /workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s999_20260501_031407 +/workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s999_20260501_031407/train.log: logfile: /workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s999_20260501_031407/v15_bijepaxlite_lqer1_nocompile_s999_20260501_031407.txt +/workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s999_20260501_031407/train.log: model_path: /workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s999_20260501_031407/final_model.pt +/workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s999_20260501_031407/train.log: quantized_model_path: /workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s999_20260501_031407/final_model.int6.ptz +/workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s999_20260501_031407/train.log: run_id: v15_bijepaxlite_lqer1_nocompile_s999_20260501_031407 +/workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s999_20260501_031407/train.log:Total submission size quantized+pergroup: 15997593 bytes +/workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s999_20260501_031407/train.log:diagnostic quantized val_loss:2.44582432 val_bpb:1.11757370 eval_time:11393ms +/workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s999_20260501_031407/train.log:ppm_mixer val_bpb:0.97373767 eval_time:451054ms order=5 H=0.999 L=0.18 T=0.8 N_tokens=47851520 N_sidecar_bytes=151074499 +/workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s999_20260501_031407/train.log:ppm_sliding val_loss:2.45502055 val_bpb:0.97373767 eval_time:496384ms diff --git a/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/LEGALITY_AUDIT.md b/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/LEGALITY_AUDIT.md new file mode 100644 index 0000000000..63b052c78d --- /dev/null +++ b/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/LEGALITY_AUDIT.md @@ -0,0 +1,115 @@ +# Legality Audit - BIJEPAX-lite + +## Verdict + +Current read: **likely legal/submittable**, assuming the existing CaseOps byte-sidecar/PPM lane is accepted. + +The BIJEPAX-lite addition itself is low-risk because it is training-only and has no evaluation-time access to future validation tokens. + +## Challenge rules checked + +From the repository README: + +- Submission artifact size is code bytes plus compressed model bytes. +- The cap is strict decimal `16,000,000` bytes. +- Evaluation may not use training data unless paid for inside the artifact. +- Validation data may not be used during training. +- Evaluation must complete within 10 minutes on 8xH100, separate from the 10-minute training cap. +- Test-time methods must score before updating on validation tokens. + +## Artifact size + +Seed 42: + +- `Serialized model quantized+pergroup: 15955181 bytes` +- `Total submission size quantized+pergroup: 15997180 bytes` +- Strict cap: `16000000 bytes` +- Headroom: `2820 bytes` + +This is tight but under cap. + +`LQER_TOP_K=1` was used specifically to create byte headroom. Earlier BIJEPA without this trim packaged at `16,004,902` bytes and was not submittable. + +## Training-only JEPA auxiliary + +Relevant implementation: + +- `class MultiDirectionalBiJEPAX` +- `def bijepax_weight_at` +- `train_model(...): bijepax_module = MultiDirectionalBiJEPAX(...)` +- `step_fn(...): loss = ce_loss + bijepax_module(hidden, ...)` + +The predictor module is created outside `base_model`: + +```python +bijepax_module = MultiDirectionalBiJEPAX(...).to(device).bfloat16() +bijepax_opt = torch.optim.Adam(bijepax_module.parameters(), ...) +``` + +It is not assigned as a child module of `base_model`, so `base_model.state_dict()` does not contain BIJEPAX predictor weights. + +Serialization only saves `base_model.state_dict()`: + +```python +torch.save(base_model.state_dict(), h.model_path) +sd_cpu = _unbank_state_dict(base_model.state_dict(), h.num_layers) +``` + +So the JEPA predictor heads are not present in the final artifact. + +## No validation leakage during training + +Training batches come from `DocumentPackingLoader(h, device)`. + +Validation data is loaded for periodic/terminal validation, but the BIJEPAX training loss only uses hidden states from training microbatches: + +```python +x, y, cu_seqlens, _max_seqlen = train_loader.next_batch(...) +ce_loss, hidden = forward_with_hidden(x, y, ...) +loss = ce_loss + bijepax_module(hidden, ...) +``` + +The BIJEPAX module does not read validation tokens, validation bytes, or validation sidecars. + +## Evaluation path + +Final score uses the existing PPM sliding evaluator: + +- `eval_val_ppm_sliding` +- `ppm_mixer val_bpb` +- `ppm_sliding val_loss / val_bpb` + +The PPM mixer operates score-before-update over the scored target stream. The implementation computes neural log probabilities first, then the byte mixer walks bytes in order and updates its tables after scoring each byte. + +Legality risk is therefore concentrated in whether reviewers accept this existing PPM/CaseOps scoring lane, not in BIJEPAX-lite. + +## Cross-document leak check + +The SmearGate cross-document leak fix is present in both hidden and TTT paths: + +```python +not_bos = (input_ids[:, 1:] != BOS_ID).to(x.dtype).unsqueeze(-1) +x = torch.cat([x[:, :1], x[:, 1:] + g * x[:, :-1] * not_bos], dim=1) +``` + +TTT is disabled for this candidate (`TTT_ENABLED=0`), but the symmetric fix is still present. + +## Eval compile + +The run uses `DISABLE_COMPILE=1`. Post-serialize evaluation also honors this: + +```python +if os.environ.get("DISABLE_COMPILE", "0") == "1": + log("eval_compile:disabled_by_env") + compiled_model = eval_model + compiled_forward_logits = eval_model.forward_logits +``` + +This avoids the compile stall encountered in the first BIJEPAX attempts. + +## Risks / reviewer-facing caveats + +- The artifact headroom is only `2820` bytes on seed 42. Do not add substantial code unless compression is rechecked. +- The PR should avoid unverifiable claims such as "BiJEPA proved 4x better on chaotic systems" unless the exact source is provided. +- The submission should clearly say the JEPA module is an auxiliary training regularizer, not an eval-time bidirectional predictor. +- If the competition reviewers consider the PPM/CaseOps byte-sidecar lane non-compliant, this candidate inherits that risk. diff --git a/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/PR_BODY.md b/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/PR_BODY.md new file mode 100644 index 0000000000..b0fdfc2b5b --- /dev/null +++ b/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/PR_BODY.md @@ -0,0 +1,69 @@ +# BIJEPAX-lite JEPA + SP8192 CaseOps PPM + +This record submits a Claude-designed, JEPA-inspired training-only auxiliary regularizer on top of the SP8192 CaseOps + per-group compression + PPM sliding stack. + +The final 3-seed mean is: + +```text +ppm_sliding val_bpb: 0.97271454 +``` + +## Results + +| Seed | Final `ppm_sliding val_bpb` | Quantized diagnostic | Artifact bytes | Train stop | Eval time | Exit | +|---:|---:|---:|---:|---:|---:|---:| +| 42 | `0.97234287` | `1.11544494` | `15,997,180` | `2014` steps / `599.843s` | `502.131s` | `0` | +| 314 | `0.97206308` | `1.11562304` | `15,999,539` | `2012` steps / `599.586s` | `499.038s` | `0` | +| 999 | `0.97373767` | `1.11757370` | `15,997,593` | `2013` steps / `599.821s` | `496.384s` | `0` | + +Three-seed sample std: `0.00089703`. + +All three runs are under: + +- strict decimal `16,000,000` byte artifact cap +- 600s training cap +- 600s evaluation cap + +## What is new + +BIJEPAX-lite adds a small custom JEPA-style hidden-state prediction objective during training: + +- hop-4 forward hidden-state prediction +- hop-4 backward hidden-state prediction +- cosine embedding-space loss +- LayerNorm-stabilized predictor heads +- no cycle head in the submitted lightweight config +- active only from `35%` to `80%` of the wallclock schedule +- separate optimizer and separate module from the base GPT + +The predictor heads are **not serialized**. Final scoring is performed by the quantized base model with the existing causal PPM sliding evaluator. + +## Compliance notes + +- `TTT_ENABLED=0` +- `LQER_TOP_K=1` keeps all seeds below the strict byte cap +- SmearGate BOS masking is present for packed-document cross-boundary safety +- BIJEPAX-lite trains only on training batches from `DocumentPackingLoader` +- BIJEPAX-lite does not access validation tokens or validation byte sidecars during training +- Final score is from `ppm_sliding` + +The folder includes: + +- `train_gpt.py` +- three seed logs +- full source/log captures for each seed +- `submission.json` +- `LEGALITY_AUDIT.md` +- `STATIC_AUDIT_NOTES.md` +- `REFERENCES.md` +- `JEPA.mp4` as a short visual/demo asset + +## Acknowledgements + +Thanks to Claude for designing the custom BIJEPAX-lite auxiliary objective and helping turn the JEPA idea into a runnable candidate. Thanks to Codex for implementing the run path, auditing legality, coordinating the 3-seed package, and assembling this PR. Thanks also to the Parameter Golf community for the public ideas and fast iteration that this stack builds on. + +## Validation + +- `python3 -m py_compile records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/train_gpt.py` +- `python3 -m json.tool records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/submission.json` +- 3 full remote runs on 8xH100 completed with `rc=0` diff --git a/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/README.md b/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/README.md new file mode 100644 index 0000000000..88d63b8d99 --- /dev/null +++ b/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/README.md @@ -0,0 +1,149 @@ +# BIJEPAX-lite JEPA + SP8192 CaseOps PPM + +This record captures a Claude-designed, JEPA-inspired training regularizer layered onto the SP8192 CaseOps + per-group compression + PPM sliding stack. + +The key idea is deliberately narrow: add a small bidirectional hidden-state prediction objective during training, then remove it entirely from the artifact. The final submitted model is still the quantized base GPT, scored causally by the PPM sliding evaluator. + +Thanks to Claude for designing the BIJEPAX-lite auxiliary objective and helping shape the experiment, to Codex for implementation, auditing, packaging, and run coordination, and to the Parameter Golf community for the public ideas this stack builds on. The inherited stack uses public Parameter Golf work by @clarkkev, @bigbag, @codemath3000, @OE-GOD, @remg1997, @joshuaswanson, @MarioPaerle, @classiclarryd, @simonbissonnette, @dexhunter, @romeerp, @samacqua, @renqianluo, @jorge-asenjo, @Omrigotlieb, @AnirudhRahul, @ndokutovich, and @H1cSuNtDr4C0n3S. See `REFERENCES.md` for the detailed lineage. + +## Score + +| Seed | Final `ppm_sliding val_bpb` | Quantized diagnostic | Artifact bytes | Train stop | Eval time | Exit | +|---:|---:|---:|---:|---:|---:|---:| +| 42 | `0.97234287` | `1.11544494` | `15,997,180` | `2014` steps / `599.843s` | `502.131s` | `0` | +| 314 | `0.97206308` | `1.11562304` | `15,999,539` | `2012` steps / `599.586s` | `499.038s` | `0` | +| 999 | `0.97373767` | `1.11757370` | `15,997,593` | `2013` steps / `599.821s` | `496.384s` | `0` | + +Three-seed mean: + +```text +0.97271454 +``` + +Sample standard deviation: + +```text +0.00089703 +``` + +All three runs are under the strict decimal `16,000,000` byte cap and under the 600s eval cap. + +## What changed + +BIJEPAX-lite adds a training-only auxiliary module: + +- hop-4 forward hidden-state prediction +- hop-4 backward hidden-state prediction +- cosine embedding-space loss +- LayerNorm-stabilized two-layer predictor heads +- no cycle head in the submitted lightweight config +- active from `35%` to `80%` of the wallclock training schedule +- separate optimizer for the auxiliary predictor +- predictor heads are never serialized into the artifact + +## Lineage and attribution + +The submitted JEPA branch is a custom training-objective experiment on top of an inherited Parameter Golf stack: + +- SP8192 tokenizer, recurrence, QK gain, and compact GPT training lineage from PR #1394, PR #1493, and PR #1855. +- Causal byte-PPM mixer lineage from PR #1795, PR #1959, and PR #1991. +- SmearGate / attention output gate lineage from modded-nanogpt @classiclarryd and PR #1667, plus the BOS cross-document leak fix discussed in PR #2014 / the PR #1797 base audit. +- Per-group `lrzip` compression lineage from PR #1586 through PR #1667 / PR #1729-style grouped serialization work. +- LQER/AWQ/asymmetric-rescale and related quantization/optimization pieces from PR #1530, PR #1797, PR #1886, PR #1923, and PR #1855. +- JEPA-Lite local-competition precedent from PR #2027. +- Online n-gram tilt / scoring overlay ideas from PR #1145 and PR #1967, although the submitted score uses the PPM path with TTT disabled. + +Our specific addition is the Claude-designed BIJEPAX-lite training-only auxiliary objective: bidirectional hop-4 hidden-state prediction with cosine loss and LayerNorm-stabilized predictor heads, removed from the serialized artifact. + +Submitted config: + +```bash +DISABLE_COMPILE=1 +CASEOPS_ENABLED=1 +VOCAB_SIZE=8192 +TTT_ENABLED=0 +PPM_MIXER_ENABLED=1 +PPM_ORDER=5 +PPM_H=0.999 +PPM_L=0.18 +PPM_T=0.80 +LQER_TOP_K=1 +BIJEPAX_ENABLED=1 +BIJEPAX_WEIGHT=0.01 +BIJEPAX_START_FRAC=0.35 +BIJEPAX_END_FRAC=0.80 +BIJEPAX_FWD_HOPS=4 +BIJEPAX_BWD_HOPS=4 +BIJEPAX_CYCLE=0 +BIJEPAX_HEAD_DIM=32 +BIJEPAX_STRIDE=64 +BIJEPAX_LR=0.001 +``` + +`LQER_TOP_K=1` is used to preserve artifact headroom. Seed 314 is the tightest artifact at `15,999,539` bytes. + +## Legality notes + +The JEPA component itself is training-only: + +- `MultiDirectionalBiJEPAX` is created as a standalone module, not as a child of the base GPT. +- Serialization saves `base_model.state_dict()`, so BIJEPAX predictor weights are absent from `final_model.int6.ptz`. +- The auxiliary loss uses hidden states from training batches only. +- `TTT_ENABLED=0`, so there is no validation-set test-time training in this record. +- SmearGate BOS masking is present to avoid packed-document cross-boundary leakage. + +The final score uses the existing causal PPM sliding path. The inherited compliance risk, if any, is the existing SP8192 CaseOps + PPM byte-sidecar lane rather than the BIJEPAX-lite auxiliary regularizer. + +See `LEGALITY_AUDIT.md` and `STATIC_AUDIT_NOTES.md` for the audit details. + +## References + +This is a custom Claude-designed JEPA-inspired auxiliary objective, not a claimed faithful reproduction of a specific BiJEPA paper. + +The design is inspired by: + +- I-JEPA: https://arxiv.org/abs/2301.08243 +- LLM-JEPA: https://arxiv.org/abs/2509.14252 +- MC-JEPA: https://arxiv.org/abs/2307.12698 + +See `REFERENCES.md` for wording notes and claims we intentionally avoid. + +## Included files + +- `train_gpt.py` - exact script used for the submitted runs +- `train_seed42.log`, `train_seed314.log`, `train_seed999.log` - run logs +- `full_seed42.txt`, `full_seed314.txt`, `full_seed999.txt` - full captured source/log snapshots +- `submission.json` - leaderboard metadata +- `LEGALITY_AUDIT.md` - compliance audit +- `STATIC_AUDIT_NOTES.md` - static code review notes +- `REFERENCES.md` - prior work and PR wording guidance +- `JEPA.mp4` - short visual/demo asset for the PR + +## Exact final lines + +Seed 42: + +```text +Total submission size quantized+pergroup: 15997180 bytes +diagnostic quantized val_loss:2.44116551 val_bpb:1.11544494 eval_time:10342ms +ppm_mixer val_bpb:0.97234287 eval_time:456845ms order=5 H=0.999 L=0.18 T=0.8 N_tokens=47851520 N_sidecar_bytes=151074499 +ppm_sliding val_loss:2.45118426 val_bpb:0.97234287 eval_time:502131ms +``` + +Seed 314: + +```text +Total submission size quantized+pergroup: 15999539 bytes +diagnostic quantized val_loss:2.44155528 val_bpb:1.11562304 eval_time:9926ms +ppm_mixer val_bpb:0.97206308 eval_time:453715ms order=5 H=0.999 L=0.18 T=0.8 N_tokens=47851520 N_sidecar_bytes=151074499 +ppm_sliding val_loss:2.45044876 val_bpb:0.97206308 eval_time:499038ms +``` + +Seed 999: + +```text +Total submission size quantized+pergroup: 15997593 bytes +diagnostic quantized val_loss:2.44582432 val_bpb:1.11757370 eval_time:11393ms +ppm_mixer val_bpb:0.97373767 eval_time:451054ms order=5 H=0.999 L=0.18 T=0.8 N_tokens=47851520 N_sidecar_bytes=151074499 +ppm_sliding val_loss:2.45502055 val_bpb:0.97373767 eval_time:496384ms +``` diff --git a/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/REFERENCES.md b/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/REFERENCES.md new file mode 100644 index 0000000000..4c5091e695 --- /dev/null +++ b/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/REFERENCES.md @@ -0,0 +1,63 @@ +# References and Prior Work + +Use this file for claims we can defend in the PR. + +## Primary references + +1. **I-JEPA: Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture** + - arXiv: https://arxiv.org/abs/2301.08243 + - Authors: Mahmoud Assran, Quentin Duval, Ishan Misra, Piotr Bojanowski, Pascal Vincent, Michael Rabbat, Yann LeCun, Nicolas Ballas. + - Relevant claim: JEPA predicts target representations from context representations instead of reconstructing input pixels/tokens. + +2. **LLM-JEPA: Large Language Models Meet Joint Embedding Predictive Architectures** + - arXiv: https://arxiv.org/abs/2509.14252 + - Authors: Hai Huang, Yann LeCun, Randall Balestriero. + - Relevant claim: JEPA-style embedding-space objectives can be combined with LLM training objectives and reported stronger/robuster performance across multiple text tasks/models. + +3. **MC-JEPA: A Joint-Embedding Predictive Architecture for Self-Supervised Learning of Motion and Content Features** + - arXiv: https://arxiv.org/abs/2307.12698 + - Authors: Adrien Bardes, Jean Ponce, Yann LeCun. + - Relevant claim: JEPA objectives can be composed with another auxiliary objective in a shared encoder; this is a useful analogy for our "CE plus training-only representation prediction" setup. + +4. **Parameter Golf README / rules** + - GitHub: https://github.com/openai/parameter-golf/blob/main/README.md + - Relevant constraints: strict decimal 16,000,000-byte artifact cap; no validation data during training; eval under 10 minutes; record submissions normally need multi-seed evidence. + +## Competition-local ideas inherited or combined + +The BIJEPAX-lite submission is not a from-scratch stack. It layers a JEPA-style training objective onto our existing SP8192 CaseOps + PPM stack. + +Important inherited components to acknowledge: + +- PR #1394 by @clarkkev: SP8192, GPTQ embeddings, depth recurrence, MuonEq-R, and related compact GPT training lineage. +- PR #1493 by @bigbag: SP8192 plus 3-layer recurrence, parallel residuals, QK gain 5.25, and the stronger recurrent base used by later SP8192 submissions. +- PR #1855 by @codemath3000: SP8192 plus LQER, sparse attention gate, BOS-fixed SmearGate, and the greedy hyperparameter stack. +- PR #1795 by @OE-GOD: strict-legal causal byte-level PPM adaptive-lambda mixer. +- PR #1959 by @remg1997: SP8192 plus byte-PPM mixer. +- PR #1991 by @joshuaswanson: SP8192 + byte-PPM tuned order/gate, including the order-5 PPM direction used here. +- modded-nanogpt @classiclarryd and PR #1667 by @MarioPaerle: SmearGate / attention-output-gate lineage. +- PR #1797 by @dexhunter and PR #2014 by @simonbissonnette: SmearGate packed-document BOS masking issue and fix lineage. +- PR #1586 by @dexhunter, PR #1667 by @MarioPaerle, and PR #1729 by @romeerp: per-group `lrzip` / grouped serialization lineage. +- PR #1530 by @samacqua, PR #1886 by @renqianluo, PR #1923 by @jorge-asenjo, and PR #1344 by @Omrigotlieb: LQER/AWQ/asymmetric-rescale/optimizer lineage reflected in the inherited stack. +- PR #1145 by @AnirudhRahul and PR #1967 by @ndokutovich: online n-gram tilt / scoring overlay ideas present in the broader code lineage. +- PR #2027 by @H1cSuNtDr4C0n3S: JEPA-Lite competition-local precedent. BIJEPAX-lite is a separate Claude-designed bidirectional/hop-4 variant, but this PR should be credited as prior JEPA-style work in the competition. + +Our new contribution in this branch is the BIJEPAX-lite training-only auxiliary objective and the run package around it. The base GPT, tokenizer, compression, and PPM evaluator are inherited/adapted from the public lines above. + +## Claims to avoid unless sourced + +Do **not** write these as facts unless we find exact sources: + +- "BiJEPA cycle consistency has proven 4x better on chaotic systems." +- "Nobody in Parameter Golf has done backward JEPA prediction." +- "LLM-JEPA proves cosine similarity is strictly superior to MSE for all text." + +Safer wording: + +- "Inspired by JEPA and LLM-JEPA, we add a cosine embedding-space auxiliary objective." +- "We tested a bidirectional hidden-state prediction variant but submitted the lighter hop-4 version because it preserved training speed and artifact size." +- "The auxiliary predictor is training-only and is absent from the serialized artifact." + +## Suggested PR phrasing + +"This submission explores whether a JEPA-style representation prediction objective can act as a cheap training regularizer for the SP8192 CaseOps + PPM lane. The predictor heads are separate modules, optimized only during training, and are never serialized. Final scoring is still performed by the causal PPM sliding evaluator over the quantized model artifact." diff --git a/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/STATIC_AUDIT_NOTES.md b/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/STATIC_AUDIT_NOTES.md new file mode 100644 index 0000000000..2500075799 --- /dev/null +++ b/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/STATIC_AUDIT_NOTES.md @@ -0,0 +1,57 @@ +# Static Audit Notes + +## Network / external access + +Search terms checked in `train_gpt_v15_bijepax.py`: + +```text +requests, urllib, http, https, wget, curl, socket, huggingface, datasets, load_dataset +``` + +No network download path appears in the script. + +The only `subprocess.run` uses found are: + +- `lrzip` compression/decompression for the local model artifact. +- optional `pyminify` code minification fallback for the code wrapper. + +These are local tool invocations, not network calls. + +## Data access + +Training: + +- `DocumentPackingLoader` reads `fineweb_train_*.bin`. +- GPTQ calibration uses `ShuffledSequenceLoader`, also training data. +- BIJEPAX-lite loss reads only hidden states from training batches. + +Validation/evaluation: + +- `ValidationData` reads `fineweb_val_*.bin`. +- CaseOps reads `fineweb_val_bytes_*.bin` for byte accounting. +- PPM sliding evaluation reads validation tokens only during final evaluation. + +No code path was found where BIJEPAX-lite trains on validation data. + +## TTT paths + +The file contains TTT utilities, but this submission sets: + +```text +TTT_ENABLED=0 +``` + +So TTT paths are not used in the submitted BIJEPAX-lite score. + +## Serialization + +The final compressed model is built from `base_model.state_dict()`. + +BIJEPAX-lite predictor heads are held in `bijepax_module`, not as `base_model` children. Therefore they are not saved into `final_model.pt` or the int6 artifact. + +## Current concern list + +- The strict artifact cap headroom is very small. Seed 42 headroom was `2820` bytes. +- `lrzip` must exist in the evaluation environment. This matches the per-group compressor lane already used in the current stack. +- Claims about "BiJEPA 4x chaotic systems" remain unverified and should not go into the PR. +- If we submit a record PR, we need the seed 314 and 999 logs plus mean/std before finalizing `submission.json`. diff --git a/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/full_seed314.txt b/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/full_seed314.txt new file mode 100644 index 0000000000..76f964ad28 --- /dev/null +++ b/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/full_seed314.txt @@ -0,0 +1,5251 @@ +==================================================================================================== +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + agree_add_boost: 0.5 + artifact_dir: /workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s314_20260501_025209 + attn_clip_sigmas: 13.0 + attn_out_gate_enabled: False + attn_out_gate_src: proj + awq_lite_bits: 8 + awq_lite_enabled: True + awq_lite_group_size: 64 + awq_lite_group_top_k: 1 + beta1: 0.9 + beta2: 0.99 + bijepa_backward_weight: 1.0 + bijepa_enabled: False + bijepax_bwd_hops: (4,) + bijepax_cycle: False + bijepax_enabled: True + bijepax_end_frac: 0.8 + bijepax_fwd_hops: (4,) + bijepax_head_dim: 32 + bijepax_lr: 0.001 + bijepax_start_frac: 0.35 + bijepax_stride: 64 + bijepax_weight: 0.01 + caseops_enabled: True + compressor: pergroup + data_dir: ./data/ + datasets_dir: /workspace/parameter-golf/data/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved + distributed: True + ema_decay: 0.9965 + embed_bits: 7 + embed_clip_sigmas: 14.0 + embed_lr: 0.6 + embed_wd: 0.085 + enable_looping_at: 0.35 + eval_seq_len: 2048 + eval_stride: 512 + fused_ce_enabled: True + gate_window: 12 + gated_attn_enabled: False + gated_attn_init_std: 0.01 + gated_attn_quant_gate: True + global_ttt_batch_seqs: 32 + global_ttt_chunk_tokens: 32768 + global_ttt_epochs: 1 + global_ttt_grad_clip: 1.0 + global_ttt_lr: 0.001 + global_ttt_momentum: 0.9 + global_ttt_respect_doc_boundaries: True + global_ttt_warmup_chunks: 0 + global_ttt_warmup_start_lr: 0.0 + gptq_calibration_batches: 16 + gptq_reserve_seconds: 0.5 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + is_main_process: True + iterations: 20000 + jepa_lite_enabled: False + jepa_lite_end_frac: 0.85 + jepa_lite_frequency: 8 + jepa_lite_hidden: 128 + jepa_lite_horizon: 4 + jepa_lite_lr: 0.01 + jepa_lite_start_frac: 0.35 + jepa_lite_stride: 16 + jepa_lite_weight: 0.03 + ln_scale: True + local_rank: 0 + logfile: /workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s314_20260501_025209/v15_bijepaxlite_lqer1_nocompile_s314_20260501_025209.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + lqer_asym_enabled: True + lqer_asym_group: 64 + lqer_enabled: True + lqer_factor_bits: 4 + lqer_gain_select: False + lqer_rank: 4 + lqer_scope: all + lqer_top_k: 1 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.026 + max_wallclock_seconds: 600.0 + min_lr: 0.1 + mlp_clip_sigmas: 11.5 + mlp_mult: 4.0 + model_dim: 512 + model_path: /workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s314_20260501_025209/final_model.pt + muon_backend_steps: 5 + muon_momentum: 0.97 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + ngram_hint_precompute_outside: True + ngram_tilt_enabled: True + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 2 + parallel_final_lane: mean + parallel_start_layer: 8 + phased_ttt_num_phases: 3 + phased_ttt_prefix_docs: 2500 + ppm_dump_inputs: False + ppm_h: 0.999 + ppm_l: 0.18 + ppm_mixer_enabled: True + ppm_order: 5 + ppm_t: 0.8 + qk_gain_init: 5.25 + quantized_model_path: /workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s314_20260501_025209/final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + rope_yarn: False + run_id: v15_bijepaxlite_lqer1_nocompile_s314_20260501_025209 + scalar_lr: 0.02 + seed: 314 + skip_gates_enabled: True + smear_gate_enabled: True + sparse_attn_gate_enabled: True + sparse_attn_gate_init_std: 0.0 + sparse_attn_gate_scale: 0.5 + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + token_boost: 2.625 + token_order: 16 + token_threshold: 0.8 + tokenizer_path: /workspace/parameter-golf/data/tokenizers/fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model + train_batch_tokens: 786432 + train_files: /workspace/parameter-golf/data/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_train_*.bin + train_log_every: 250 + train_seq_len: 2048 + ttt_batch_size: 64 + ttt_beta1: 0.0 + ttt_beta2: 0.99 + ttt_chunk_size: 48 + ttt_enabled: False + ttt_eval_batches: + ttt_eval_seq_len: 2048 + ttt_grad_steps: 1 + ttt_k_lora: True + ttt_local_lr_mult: 0.75 + ttt_lora_lr: 0.0001 + ttt_lora_rank: 80 + ttt_mask: no_qv + ttt_mlp_lora: True + ttt_o_lora: True + ttt_optimizer: adam + ttt_q_lora: False + ttt_train_window_tokens: 0 + ttt_v_lora: False + ttt_weight_decay: 0.5 + val_batch_tokens: 524288 + val_bytes_files: /workspace/parameter-golf/data/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_val_bytes_*.bin + val_doc_fraction: 1.0 + val_files: /workspace/parameter-golf/data/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.85 + warmup_steps: 20 + within_boost: 0.75 + within_tau: 0.45 + word_boost: 0.75 + word_normalize: strip_punct_lower + word_order: 4 + word_tau: 0.65 + world_size: 8 + xsa_last_n: 11 +==================================================================================================== +Source code: +==================================================================================================== +import base64, collections, copy, fcntl, glob, io, lzma, math, os +from pathlib import Path +import random, re, subprocess, sys, time, uuid, numpy as np, sentencepiece as spm, torch, torch.distributed as dist, torch.nn.functional as F +from torch import Tensor, nn +from flash_attn_interface import ( + flash_attn_func as flash_attn_3_func, + flash_attn_varlen_func, +) +from concurrent.futures import ThreadPoolExecutor +import triton +import triton.language as tl +from triton.tools.tensor_descriptor import TensorDescriptor + + +# ===== Fused softcapped cross-entropy (Triton) — training-only path ===== +# Replaces the eager +# logits_softcap = softcap * tanh(logits / softcap) +# F.cross_entropy(logits_softcap.float(), targets, reduction="mean") +# sequence with a single fused kernel that reads logits_proj once, applies +# softcap in-register, and computes (LSE, loss) in one streaming pass. The +# backward kernel mirrors the forward so there's no stored softcapped logits. +# Numerically identical to the eager path up to fp32 accumulation differences. +_FUSED_CE_LIBRARY = "pgsubmission1draft7fusedce" +_FUSED_CE_BLOCK_SIZE = 1024 +_FUSED_CE_NUM_WARPS = 4 + + +@triton.jit +def _softcapped_ce_fwd_kernel( + logits_ptr, losses_ptr, lse_ptr, targets_ptr, + stride_logits_n, stride_logits_v, + n_rows, n_cols, softcap, + block_size: tl.constexpr, +): + row_idx = tl.program_id(0).to(tl.int64) + logits_row_ptr = logits_ptr + row_idx * stride_logits_n + max_val = -float("inf") + sum_exp = 0.0 + A = 2.0 * softcap + inv_C = 2.0 / softcap + for off in range(0, n_cols, block_size): + cols = off + tl.arange(0, block_size) + mask = cols < n_cols + val = tl.load( + logits_row_ptr + cols * stride_logits_v, + mask=mask, other=-float("inf"), + ).to(tl.float32) + z = A * tl.sigmoid(val * inv_C) + z = tl.where(mask, z, -float("inf")) + curr_max = tl.max(z, axis=0) + new_max = tl.maximum(max_val, curr_max) + sum_exp = sum_exp * tl.exp(max_val - new_max) + tl.sum(tl.exp(z - new_max), axis=0) + max_val = new_max + lse = max_val + tl.log(sum_exp) + tl.store(lse_ptr + row_idx, lse) + target = tl.load(targets_ptr + row_idx).to(tl.int32) + target_val = tl.load(logits_row_ptr + target * stride_logits_v).to(tl.float32) + target_z = A * tl.sigmoid(target_val * inv_C) + tl.store(losses_ptr + row_idx, lse - target_z) + + +@triton.jit +def _softcapped_ce_bwd_kernel( + grad_logits_ptr, grad_losses_ptr, lse_ptr, logits_ptr, targets_ptr, + stride_logits_n, stride_logits_v, + stride_grad_n, stride_grad_v, + n_rows, n_cols, softcap, + block_size: tl.constexpr, +): + row_idx = tl.program_id(0).to(tl.int64) + logits_row_ptr = logits_ptr + row_idx * stride_logits_n + grad_row_ptr = grad_logits_ptr + row_idx * stride_grad_n + lse = tl.load(lse_ptr + row_idx) + grad_loss = tl.load(grad_losses_ptr + row_idx).to(tl.float32) + target = tl.load(targets_ptr + row_idx).to(tl.int32) + A = 2.0 * softcap + inv_C = 2.0 / softcap + dz_dx_scale = A * inv_C + for off in range(0, n_cols, block_size): + cols = off + tl.arange(0, block_size) + mask = cols < n_cols + val = tl.load( + logits_row_ptr + cols * stride_logits_v, + mask=mask, other=0.0, + ).to(tl.float32) + sigmoid_u = tl.sigmoid(val * inv_C) + z = A * sigmoid_u + probs = tl.exp(z - lse) + grad_z = grad_loss * (probs - tl.where(cols == target, 1.0, 0.0)) + grad_x = grad_z * (dz_dx_scale * sigmoid_u * (1.0 - sigmoid_u)) + tl.store(grad_row_ptr + cols * stride_grad_v, grad_x, mask=mask) + + +def _validate_softcapped_ce_inputs( + logits: Tensor, targets: Tensor, softcap: float, +) -> tuple[Tensor, Tensor]: + if logits.ndim != 2: + raise ValueError(f"Expected logits.ndim=2, got {logits.ndim}") + if targets.ndim != 1: + raise ValueError(f"Expected targets.ndim=1, got {targets.ndim}") + if logits.shape[0] != targets.shape[0]: + raise ValueError( + f"Expected matching rows, got logits={tuple(logits.shape)} targets={tuple(targets.shape)}" + ) + if not logits.is_cuda or not targets.is_cuda: + raise ValueError("softcapped_cross_entropy requires CUDA tensors") + if softcap <= 0.0: + raise ValueError(f"softcap must be positive, got {softcap}") + if logits.dtype not in (torch.float16, torch.bfloat16, torch.float32): + raise ValueError(f"Unsupported logits dtype: {logits.dtype}") + logits = logits.contiguous() + targets = targets.contiguous() + if targets.dtype != torch.int64: + targets = targets.to(dtype=torch.int64) + return logits, targets + + +@torch.library.custom_op(f"{_FUSED_CE_LIBRARY}::softcapped_ce", mutates_args=()) +def softcapped_ce_op(logits: Tensor, targets: Tensor, softcap: float) -> tuple[Tensor, Tensor]: + logits, targets = _validate_softcapped_ce_inputs(logits, targets, float(softcap)) + n_rows, n_cols = logits.shape + losses = torch.empty((n_rows,), device=logits.device, dtype=torch.float32) + lse = torch.empty((n_rows,), device=logits.device, dtype=torch.float32) + _softcapped_ce_fwd_kernel[(n_rows,)]( + logits, losses, lse, targets, + logits.stride(0), logits.stride(1), + n_rows, n_cols, float(softcap), + block_size=_FUSED_CE_BLOCK_SIZE, num_warps=_FUSED_CE_NUM_WARPS, + ) + return losses, lse + + +@softcapped_ce_op.register_fake +def _(logits: Tensor, targets: Tensor, softcap: float): + if logits.ndim != 2 or targets.ndim != 1: + raise ValueError("softcapped_ce fake impl expects 2D logits and 1D targets") + if logits.shape[0] != targets.shape[0]: + raise ValueError( + f"Expected matching rows, got logits={tuple(logits.shape)} targets={tuple(targets.shape)}" + ) + n_rows = logits.shape[0] + return ( + logits.new_empty((n_rows,), dtype=torch.float32), + logits.new_empty((n_rows,), dtype=torch.float32), + ) + + +@torch.library.custom_op(f"{_FUSED_CE_LIBRARY}::softcapped_ce_backward", mutates_args=()) +def softcapped_ce_backward_op( + logits: Tensor, targets: Tensor, lse: Tensor, grad_losses: Tensor, softcap: float, +) -> Tensor: + logits, targets = _validate_softcapped_ce_inputs(logits, targets, float(softcap)) + lse = lse.contiguous() + grad_losses = grad_losses.contiguous().to(dtype=torch.float32) + if lse.ndim != 1 or grad_losses.ndim != 1: + raise ValueError("Expected 1D lse and grad_losses") + if lse.shape[0] != logits.shape[0] or grad_losses.shape[0] != logits.shape[0]: + raise ValueError( + f"Expected row-aligned lse/grad_losses, got logits={tuple(logits.shape)} " + f"lse={tuple(lse.shape)} grad_losses={tuple(grad_losses.shape)}" + ) + grad_logits = torch.empty_like(logits) + n_rows, n_cols = logits.shape + _softcapped_ce_bwd_kernel[(n_rows,)]( + grad_logits, grad_losses, lse, logits, targets, + logits.stride(0), logits.stride(1), + grad_logits.stride(0), grad_logits.stride(1), + n_rows, n_cols, float(softcap), + block_size=_FUSED_CE_BLOCK_SIZE, num_warps=_FUSED_CE_NUM_WARPS, + ) + return grad_logits + + +@softcapped_ce_backward_op.register_fake +def _(logits: Tensor, targets: Tensor, lse: Tensor, grad_losses: Tensor, softcap: float): + if logits.ndim != 2 or targets.ndim != 1 or lse.ndim != 1 or grad_losses.ndim != 1: + raise ValueError("softcapped_ce_backward fake impl expects 2D logits and 1D row tensors") + if ( + logits.shape[0] != targets.shape[0] + or logits.shape[0] != lse.shape[0] + or logits.shape[0] != grad_losses.shape[0] + ): + raise ValueError("softcapped_ce_backward fake impl expects row-aligned tensors") + return logits.new_empty(logits.shape) + + +def _softcapped_ce_setup_context( + ctx: torch.autograd.function.FunctionCtx, inputs, output, +) -> None: + logits, targets, softcap = inputs + _losses, lse = output + ctx.save_for_backward(logits, targets, lse) + ctx.softcap = float(softcap) + + +def _softcapped_ce_backward( + ctx: torch.autograd.function.FunctionCtx, grad_losses: Tensor, grad_lse: "Tensor | None", +): + del grad_lse + logits, targets, lse = ctx.saved_tensors + grad_logits = torch.ops.pgsubmission1draft7fusedce.softcapped_ce_backward( + logits, targets, lse, grad_losses, ctx.softcap + ) + return grad_logits, None, None + + +softcapped_ce_op.register_autograd( + _softcapped_ce_backward, setup_context=_softcapped_ce_setup_context, +) + + +def softcapped_cross_entropy( + logits: Tensor, targets: Tensor, softcap: float, reduction: str = "mean", +) -> Tensor: + losses, _lse = torch.ops.pgsubmission1draft7fusedce.softcapped_ce( + logits, targets, float(softcap) + ) + if reduction == "none": + return losses + if reduction == "sum": + return losses.sum() + if reduction == "mean": + return losses.mean() + raise ValueError(f"Unsupported reduction={reduction!r}") + + +class _BiJEPAXHead(nn.Module): + def __init__(self, d_model: int, head_dim: int): + super().__init__() + self.w1 = nn.Linear(d_model, head_dim, bias=False) + self.ln = nn.LayerNorm(head_dim) + self.w2 = nn.Linear(head_dim, d_model, bias=False) + nn.init.orthogonal_(self.w1.weight) + nn.init.zeros_(self.w2.weight) + + def forward(self, x: Tensor) -> Tensor: + return self.w2(F.gelu(self.ln(self.w1(x.float())))) + + +class MultiDirectionalBiJEPAX(nn.Module): + """Training-only multi-hop bidirectional/cycle JEPA regularizer.""" + + def __init__( + self, + d_model: int, + fwd_hops: tuple[int, ...] = (1, 4, 16), + bwd_hops: tuple[int, ...] = (1, 4), + cycle: bool = True, + head_dim: int = 64, + ): + super().__init__() + self.fwd_hops = list(fwd_hops) + self.bwd_hops = list(bwd_hops) + self.cycle = cycle + self.fwd_heads = nn.ModuleList( + [_BiJEPAXHead(d_model, head_dim) for _ in fwd_hops] + ) + self.bwd_heads = nn.ModuleList( + [_BiJEPAXHead(d_model, head_dim) for _ in bwd_hops] + ) + self.cyc_heads = nn.ModuleList( + [_BiJEPAXHead(d_model, head_dim) for _ in fwd_hops] + ) if cycle else nn.ModuleList() + + @staticmethod + def _cos_loss(pred: Tensor, tgt: Tensor) -> Tensor: + pred = F.normalize(pred, dim=-1) + tgt = F.normalize(tgt.float().detach(), dim=-1) + return (1.0 - (pred * tgt).sum(dim=-1)).mean() + + def forward(self, hidden: Tensor, weight: float, stride: int) -> Tensor: + if weight <= 0.0 or hidden.size(1) < 2: + return hidden.new_zeros(()) + stride = max(1, int(stride)) + seq_len = hidden.size(1) + losses: list[Tensor] = [] + for i, hop in enumerate(self.fwd_hops): + if seq_len <= hop: + continue + src = hidden[:, : seq_len - hop : stride, :] + tgt = hidden[:, hop : hop + src.size(1) * stride : stride, :] + n = min(src.size(1), tgt.size(1)) + if n <= 0: + continue + src = src[:, :n, :] + tgt = tgt[:, :n, :] + losses.append(self._cos_loss(self.fwd_heads[i](src), tgt)) + if self.cycle and i < len(self.cyc_heads): + losses.append(self._cos_loss(self.cyc_heads[i](tgt.detach()), src)) + for i, hop in enumerate(self.bwd_hops): + if seq_len <= hop: + continue + src = hidden[:, hop::stride, :] + tgt = hidden[:, : src.size(1) * stride : stride, :] + n = min(src.size(1), tgt.size(1)) + if n <= 0: + continue + losses.append(self._cos_loss(self.bwd_heads[i](src[:, :n, :]), tgt[:, :n, :])) + if not losses: + return hidden.new_zeros(()) + return weight * (sum(losses) / len(losses)) + + +def bijepax_weight_at(frac: float, start: float, end: float, peak: float) -> float: + if frac < start or frac >= end: + return 0.0 + span = max(end - start, 1e-9) + ramp = min(span * 0.15, 0.08) + if frac < start + ramp: + return peak * (frac - start) / max(ramp, 1e-9) + if frac >= end - ramp: + return peak * (end - frac) / max(ramp, 1e-9) + return peak + + +class Hyperparameters: + data_dir = os.environ.get("DATA_DIR", "./data/") + seed = int(os.environ.get("SEED", 1337)) + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_frac = float(os.environ.get("WARMDOWN_FRAC", 0.85)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786432)) + # Fused softcapped CE (Triton). Training-only — forward_logits eval path still uses + # eager softcap+F.cross_entropy. Default ON since validated as at-worst neutral. + fused_ce_enabled = bool(int(os.environ.get("FUSED_CE_ENABLED", "1"))) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 6e2)) + val_batch_tokens = int(os.environ.get("VAL_BATCH_TOKENS", 524288)) + # v13 is the sidecar-aware PPM lane. These defaults match the under-cap + # H100 package runs instead of the older TTT-first v12 defaults. + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 8192)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 11)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 4.0)) + skip_gates_enabled = bool(int(os.environ.get("SKIP_GATES_ENABLED", "1"))) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 3e1)) + rope_base = float(os.environ.get("ROPE_BASE", 1e4)) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + rope_train_seq_len = int(os.environ.get("ROPE_TRAIN_SEQ_LEN", 2048)) + rope_yarn = bool(int(os.environ.get("ROPE_YARN", "0"))) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 5.25)) + num_loops = int(os.environ.get("NUM_LOOPS", 2)) + loop_start = int(os.environ.get("LOOP_START", 3)) + loop_end = int(os.environ.get("LOOP_END", 5)) + enable_looping_at = float(os.environ.get("ENABLE_LOOPING_AT", 0.35)) + parallel_start_layer = int(os.environ.get("PARALLEL_START_LAYER", 8)) + parallel_final_lane = os.environ.get("PARALLEL_FINAL_LANE", "mean") + min_lr = float(os.environ.get("MIN_LR", 0.1)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.026)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.97)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float( + os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92) + ) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + muon_row_normalize = bool(int(os.environ.get("MUON_ROW_NORMALIZE", "1"))) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.99)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-08)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 512)) + adam_wd = float(os.environ.get("ADAM_WD", 0.02)) + muon_wd = float(os.environ.get("MUON_WD", 0.095)) + embed_wd = float(os.environ.get("EMBED_WD", 0.085)) + ema_decay = float(os.environ.get("EMA_DECAY", 0.9965)) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 80)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.0001)) + ttt_local_lr_mult = float(os.environ.get("TTT_LOCAL_LR_MULT", 0.75)) + ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 48)) + ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 2048)) + ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64)) + ttt_grad_steps = int(os.environ.get("TTT_GRAD_STEPS", 1)) + ttt_train_window_tokens = int(os.environ.get("TTT_TRAIN_WINDOW_TOKENS", 0)) + # V19: PR #1886 (renqianluo) + sunnypatneedi research log 2026-04-28 found that + # the Triton fused-CE kernel's fp32-accumulation interacts with warm-start LoRA-A + # to destabilize seeds 314/1337 at TTT_WEIGHT_DECAY=1.0. Raising the default to + # 2.0 prevents seed collapse without measurably moving stable seeds. + ttt_weight_decay = float(os.environ.get("TTT_WEIGHT_DECAY", 0.5)) + ttt_beta1 = float(os.environ.get("TTT_BETA1", 0)) + ttt_beta2 = float(os.environ.get("TTT_BETA2", 0.99)) + ttt_mask = os.environ.get("TTT_MASK", "no_qv").strip().lower() + _ttt_q_default = "1" + _ttt_v_default = "1" + if ttt_mask in ("", "all", "baseline_all"): + pass + elif ttt_mask == "no_q": + _ttt_q_default = "0" + elif ttt_mask == "no_v": + _ttt_v_default = "0" + elif ttt_mask == "no_qv": + _ttt_q_default = "0" + _ttt_v_default = "0" + else: + raise ValueError(f"Unsupported TTT_MASK={ttt_mask!r}") + ttt_q_lora = bool(int(os.environ.get("TTT_Q_LORA", _ttt_q_default))) + ttt_k_lora = bool(int(os.environ.get("TTT_K_LORA", "1"))) + ttt_v_lora = bool(int(os.environ.get("TTT_V_LORA", _ttt_v_default))) + ttt_mlp_lora = bool(int(os.environ.get("TTT_MLP_LORA", "1"))) + ttt_o_lora = bool(int(os.environ.get("TTT_O_LORA", "1"))) + ttt_optimizer = os.environ.get("TTT_OPTIMIZER", "adam") + ttt_eval_batches = os.environ.get("TTT_EVAL_BATCHES", "") + val_doc_fraction = float(os.environ.get("VAL_DOC_FRACTION", 1.0)) + compressor = os.environ.get("COMPRESSOR", "pergroup") + gptq_calibration_batches = int(os.environ.get("GPTQ_CALIBRATION_BATCHES", 16)) + gptq_reserve_seconds = float(os.environ.get("GPTQ_RESERVE_SECONDS", 0.5)) + phased_ttt_prefix_docs = int(os.environ.get("PHASED_TTT_PREFIX_DOCS", 2500)) + phased_ttt_num_phases = int(os.environ.get("PHASED_TTT_NUM_PHASES", 3)) + global_ttt_lr = float(os.environ.get("GLOBAL_TTT_LR", 0.001)) + global_ttt_momentum = float(os.environ.get("GLOBAL_TTT_MOMENTUM", 0.9)) + global_ttt_epochs = int(os.environ.get("GLOBAL_TTT_EPOCHS", 1)) + global_ttt_chunk_tokens = int(os.environ.get("GLOBAL_TTT_CHUNK_TOKENS", 32768)) + global_ttt_batch_seqs = int(os.environ.get("GLOBAL_TTT_BATCH_SEQS", 32)) + global_ttt_warmup_start_lr = float(os.environ.get("GLOBAL_TTT_WARMUP_START_LR", 0.0)) + global_ttt_warmup_chunks = int(os.environ.get("GLOBAL_TTT_WARMUP_CHUNKS", 0)) + global_ttt_grad_clip = float(os.environ.get("GLOBAL_TTT_GRAD_CLIP", 1.0)) + global_ttt_respect_doc_boundaries = bool(int(os.environ.get("GLOBAL_TTT_RESPECT_DOC_BOUNDARIES", "1"))) + matrix_bits = int(os.environ.get("MATRIX_BITS", 6)) + embed_bits = int(os.environ.get("EMBED_BITS", 7)) + matrix_clip_sigmas = float(os.environ.get("MATRIX_CLIP_SIGMAS", 12.85)) + embed_clip_sigmas = float(os.environ.get("EMBED_CLIP_SIGMAS", 14.0)) + mlp_clip_sigmas = float(os.environ.get("MLP_CLIP_SIGMAS", 11.5)) + attn_clip_sigmas = float(os.environ.get("ATTN_CLIP_SIGMAS", 13.0)) + # AttnOutGate (per-head multiplicative output gate, PR #1667 MarioPaerle). + # Zero-init weight: 2*sigmoid(0)=1 -> transparent at start. Source defaults to + # block input x ('proj'); 'q' uses raw Q projection output. + attn_out_gate_enabled = bool(int(os.environ.get("ATTN_OUT_GATE_ENABLED", "0"))) + attn_out_gate_src = os.environ.get("ATTN_OUT_GATE_SRC", "proj") + # SmearGate (input-dependent forward-1 token smear, modded-nanogpt @classiclarryd + # via PR #1667). x_t <- x_t + lam * sigmoid(W*x_t[:gate_window]) * x_{t-1}. + # lam=0 + W=0 -> transparent at init. + smear_gate_enabled = bool(int(os.environ.get("SMEAR_GATE_ENABLED", "1"))) + # Window: first GATE_WINDOW dims of the source feed the gate projection. + gate_window = int(os.environ.get("GATE_WINDOW", 12)) + # Gated Attention (Qwen, NeurIPS 2025 Best Paper, arXiv:2505.06708; + # qiuzh20/gated_attention). Per-head sigmoid gate on SDPA output, BEFORE + # out_proj. Gate input = full block input x (paper's headwise G1 variant + # driven from hidden_states). W_g shape (num_heads, dim), plain sigmoid. + # Near-zero init gives g~0.5 at step 0 (half attention output); per-block + # attn_scale (init 1.0) compensates during training. Name contains + # "attn_gate" so CONTROL_TENSOR_NAME_PATTERNS routes it to scalar AdamW. + gated_attn_enabled = bool(int(os.environ.get("GATED_ATTN_ENABLED", "0"))) + gated_attn_init_std = float(os.environ.get("GATED_ATTN_INIT_STD", 0.01)) + # Dedicated int8-per-row quantization for `attn_gate_w` tensors. These are + # small ((num_heads, dim) = (8, 512) = 4096 params) and bypass GPTQ via the + # numel<=65536 passthrough branch -> stored as fp16 (8 KB/layer, ~65 KB total + # compressed). int8-per-row cuts the raw tensor in half with negligible BPB + # impact: scales per head (8 values), symmetric quant over [-127, 127]. + # No Hessian needed (gate weights not in collect_hessians()). + gated_attn_quant_gate = bool(int(os.environ.get("GATED_ATTN_QUANT_GATE", "1"))) + # Sparse Attention Gate (modded-nanogpt-style). Keeps dense SDPA and only + # swaps the output-gate input to the first GATE_WINDOW residual dims. + # W_g: (num_heads, gate_window) = (8, 12) = 96 params/layer (~44K total), + # vs dense GatedAttn's (8, 512) = 4K/layer (~44K diff). Name "attn_gate_w" + # is shared so quant routing and int8 gate passthrough Just Work. Gate + # passthrough int8 still applies via GATED_ATTN_QUANT_GATE=1. + # Mutually exclusive with ATTN_OUT_GATE_ENABLED and GATED_ATTN_ENABLED. + sparse_attn_gate_enabled = bool(int(os.environ.get("SPARSE_ATTN_GATE_ENABLED", "1"))) + sparse_attn_gate_init_std = float(os.environ.get("SPARSE_ATTN_GATE_INIT_STD", 0.0)) + sparse_attn_gate_scale = float(os.environ.get("SPARSE_ATTN_GATE_SCALE", 0.5)) + # LQER asymmetric rank-k correction on top-K quant-error tensors (PR #1530 v2 port). + # Computes SVD of E = W_fp - W_quant, packs top-r A,B as INT2/INT4 (asym) or INTk (sym). + lqer_enabled = bool(int(os.environ.get("LQER_ENABLED", "1"))) + lqer_rank = int(os.environ.get("LQER_RANK", 4)) + lqer_top_k = int(os.environ.get("LQER_TOP_K", 3)) + lqer_factor_bits = int(os.environ.get("LQER_FACTOR_BITS", 4)) + lqer_asym_enabled = bool(int(os.environ.get("LQER_ASYM_ENABLED", "1"))) + lqer_asym_group = int(os.environ.get("LQER_ASYM_GROUP", "64")) + lqer_scope = os.environ.get("LQER_SCOPE", "all") + lqer_gain_select = bool(int(os.environ.get("LQER_GAIN_SELECT", "0"))) + awq_lite_enabled = bool(int(os.environ.get("AWQ_LITE_ENABLED", "1"))) + awq_lite_bits = int(os.environ.get("AWQ_LITE_BITS", "8")) + awq_lite_group_top_k = int(os.environ.get("AWQ_LITE_GROUP_TOP_K", "1")) + awq_lite_group_size = int(os.environ.get("AWQ_LITE_GROUP_SIZE", "64")) + # PR #1145/#1967 online n-gram tilt. This is a causal scoring overlay: + # prefix-only token/within-word/word experts propose one hint token, then + # the per-token NLL is adjusted with closed-form softmax renormalization. + ngram_tilt_enabled = bool(int(os.environ.get("NGRAM_TILT_ENABLED", "1"))) + token_order = int(os.environ.get("TOKEN_ORDER", "16")) + token_threshold = float(os.environ.get("TOKEN_THRESHOLD", "0.800")) + token_boost = float(os.environ.get("TOKEN_BOOST", "2.625")) + within_tau = float(os.environ.get("WITHIN_TAU", "0.450")) + within_boost = float(os.environ.get("WITHIN_BOOST", "0.750")) + word_order = int(os.environ.get("WORD_ORDER", "4")) + word_normalize = os.environ.get("WORD_NORMALIZE", "strip_punct_lower") + word_tau = float(os.environ.get("WORD_TAU", "0.650")) + word_boost = float(os.environ.get("WORD_BOOST", "0.750")) + agree_add_boost = float(os.environ.get("AGREE_ADD_BOOST", "0.500")) + ngram_hint_precompute_outside = bool(int(os.environ.get("NGRAM_HINT_PRECOMPUTE_OUTSIDE", "1"))) + ppm_mixer_enabled = bool(int(os.environ.get("PPM_MIXER_ENABLED", "1"))) + ppm_order = int(os.environ.get("PPM_ORDER", "5")) + ppm_h = float(os.environ.get("PPM_H", "0.995")) + ppm_l = float(os.environ.get("PPM_L", "0.20")) + ppm_t = float(os.environ.get("PPM_T", "0.80")) + ppm_dump_inputs = bool(int(os.environ.get("PPM_DUMP_INPUTS", "0"))) + # JEPA-Lite training regularizer from PR #2027. The predictor is never + # serialized; it only nudges hidden states during training. + jepa_lite_enabled = bool(int(os.environ.get("JEPA_LITE_ENABLED", "0"))) + jepa_lite_weight = float(os.environ.get("JEPA_LITE_WEIGHT", "0.03")) + jepa_lite_hidden = int(os.environ.get("JEPA_LITE_HIDDEN", "128")) + jepa_lite_horizon = int(os.environ.get("JEPA_LITE_HORIZON", "4")) + jepa_lite_stride = int(os.environ.get("JEPA_LITE_STRIDE", "16")) + jepa_lite_frequency = int(os.environ.get("JEPA_LITE_FREQUENCY", "8")) + jepa_lite_start_frac = float(os.environ.get("JEPA_LITE_START_FRAC", "0.35")) + jepa_lite_end_frac = float(os.environ.get("JEPA_LITE_END_FRAC", "0.85")) + jepa_lite_lr = float(os.environ.get("JEPA_LITE_LR", "0.01")) + bijepa_enabled = bool(int(os.environ.get("BIJEPA_ENABLED", "0"))) + bijepa_backward_weight = float(os.environ.get("BIJEPA_BACKWARD_WEIGHT", "1.0")) + bijepax_enabled = bool(int(os.environ.get("BIJEPAX_ENABLED", "0"))) + bijepax_weight = float(os.environ.get("BIJEPAX_WEIGHT", "0.04")) + bijepax_start_frac = float(os.environ.get("BIJEPAX_START_FRAC", "0.05")) + bijepax_end_frac = float(os.environ.get("BIJEPAX_END_FRAC", "0.80")) + bijepax_fwd_hops = tuple( + int(x) for x in os.environ.get("BIJEPAX_FWD_HOPS", "1,4,16").split(",") if x + ) + bijepax_bwd_hops = tuple( + int(x) for x in os.environ.get("BIJEPAX_BWD_HOPS", "1,4").split(",") if x + ) + bijepax_cycle = bool(int(os.environ.get("BIJEPAX_CYCLE", "1"))) + bijepax_head_dim = int(os.environ.get("BIJEPAX_HEAD_DIM", "64")) + bijepax_stride = int(os.environ.get("BIJEPAX_STRIDE", "8")) + bijepax_lr = float(os.environ.get("BIJEPAX_LR", "3e-4")) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + is_main_process = rank == 0 + grad_accum_steps = 8 // world_size + # CaseOps integration: optional override of dataset root + tokenizer path. + # When CASEOPS_ENABLED=1, the wrapper loads a per-token byte sidecar + # (fineweb_val_bytes_*.bin, identical shard layout to val_*.bin) and uses + # it as the canonical raw-byte budget for BPB accounting. The sidecar + # REPLACES the build_sentencepiece_luts byte-counting path entirely. + caseops_enabled = bool(int(os.environ.get("CASEOPS_ENABLED", "1"))) + _default_caseops_data = os.path.join( + data_dir, + "datasets", + "fineweb10B_sp8192_caseops", + "datasets", + "datasets", + "fineweb10B_sp8192_lossless_caps_caseops_v1_reserved", + ) + _default_caseops_tok = os.path.join( + data_dir, + "datasets", + "fineweb10B_sp8192_caseops", + "datasets", + "tokenizers", + "fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model", + ) + if caseops_enabled: + datasets_dir = os.environ.get("DATA_PATH", _default_caseops_data) + tokenizer_path = os.environ.get("TOKENIZER_PATH", _default_caseops_tok) + else: + datasets_dir = os.environ.get( + "DATA_PATH", + os.path.join(data_dir, "datasets", f"fineweb10B_sp{vocab_size}"), + ) + tokenizer_path = os.environ.get( + "TOKENIZER_PATH", + os.path.join(data_dir, "tokenizers", f"fineweb_{vocab_size}_bpe.model"), + ) + train_files = os.path.join(datasets_dir, "fineweb_train_*.bin") + val_files = os.path.join(datasets_dir, "fineweb_val_*.bin") + val_bytes_files = os.path.join(datasets_dir, "fineweb_val_bytes_*.bin") + artifact_dir = os.environ.get("ARTIFACT_DIR", "") + logfile = ( + os.path.join(artifact_dir, f"{run_id}.txt") + if artifact_dir + else f"logs/{run_id}.txt" + ) + model_path = ( + os.path.join(artifact_dir, "final_model.pt") + if artifact_dir + else "final_model.pt" + ) + quantized_model_path = ( + os.path.join(artifact_dir, "final_model.int6.ptz") + if artifact_dir + else "final_model.int6.ptz" + ) + + +_logger_hparams = None + + +def set_logging_hparams(h): + global _logger_hparams + _logger_hparams = h + + +def log(msg, console=True): + if _logger_hparams is None: + print(msg) + return + if _logger_hparams.is_main_process: + if console: + print(msg) + if _logger_hparams.logfile is not None: + with open(_logger_hparams.logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + +class ValidationData: + def __init__(self, h, device): + self.sp = spm.SentencePieceProcessor(model_file=h.tokenizer_path) + if int(self.sp.vocab_size()) != h.vocab_size: + raise ValueError( + f"VOCAB_SIZE={h.vocab_size} does not match tokenizer vocab_size={int(self.sp.vocab_size())}" + ) + self.val_tokens = load_validation_tokens(h.val_files, h.eval_seq_len) + self.caseops_enabled = bool(getattr(h, "caseops_enabled", False)) + if self.caseops_enabled: + self.base_bytes_lut = None + self.has_leading_space_lut = None + self.is_boundary_token_lut = None + else: + ( + self.base_bytes_lut, + self.has_leading_space_lut, + self.is_boundary_token_lut, + ) = build_sentencepiece_luts(self.sp, h.vocab_size, device) + self.val_bytes = None + if self.caseops_enabled: + self.val_bytes = load_validation_byte_sidecar( + h.val_bytes_files, h.eval_seq_len, self.val_tokens.numel() + ) + + +def build_sentencepiece_luts(sp, vocab_size, device): + sp_vocab_size = int(sp.vocab_size()) + assert ( + sp.piece_to_id("▁") != sp.unk_id() + ), "Tokenizer must have '▁' (space) as its own token for correct BPB byte counting" + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern, seq_len): + # Filter out CaseOps byte sidecar shards which share the val_*.bin glob. + files = [ + Path(p) + for p in sorted(glob.glob(pattern)) + if "_bytes_" not in Path(p).name + ] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = (tokens.numel() - 1) // seq_len * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def load_validation_byte_sidecar(pattern, seq_len, expected_len): + """Load CaseOps per-token byte sidecar(s). Same shard layout as token shards + (256 int32 header + uint16 array). Each entry = canonical raw-text byte + budget for that token in the corresponding val shard. Returns a CPU + int16 tensor sliced to match expected_len (i.e. val_tokens length).""" + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No byte sidecar files for pattern: {pattern}") + shards = [load_data_shard(file) for file in files] + # load_data_shard returns uint16 — that's exactly what the sidecar stores. + bytes_full = torch.cat(shards).contiguous() + if bytes_full.numel() < expected_len: + raise ValueError( + f"Byte sidecar too short: {bytes_full.numel()} < val_tokens {expected_len}" + ) + return bytes_full[:expected_len].to(torch.int32) + + +def load_data_shard(file): + header_bytes = 256 * np.dtype(" 0: + pos = start + while pos < end: + seg_starts.append(pos) + pos += max_doc_len + else: + seg_starts.append(start) + boundaries = seg_starts + [total_len] + padded_len = get_next_multiple_of_n(len(boundaries), bucket_size) + cu = torch.full((padded_len,), total_len, dtype=torch.int32, device=device) + cu[: len(boundaries)] = torch.tensor(boundaries, dtype=torch.int32, device=device) + seg_ends = seg_starts[1:] + [total_len] + max_seqlen = max(end - start for start, end in zip(seg_starts, seg_ends)) + return cu, max_seqlen + +class DocumentPackingLoader: + _shard_pool = ThreadPoolExecutor(1) + + def __init__(self, h, device, cu_bucket_size=64): + self.rank = h.rank + self.world_size = h.world_size + self.device = device + self.cu_bucket_size = cu_bucket_size + self.max_seq_len = h.train_seq_len + all_files = [Path(p) for p in sorted(glob.glob(h.train_files))] + if not all_files: + raise FileNotFoundError(f"No files found for pattern: {h.train_files}") + self.files = all_files + self.file_iter = iter(self.files) + self._init_shard(load_data_shard(next(self.file_iter))) + self._next_shard = self._submit_next_shard() + self._batch_pool = ThreadPoolExecutor(1) + self._prefetch_queue = [] + + def _init_shard(self, tokens): + global BOS_ID + self.tokens = tokens + self.shard_size = tokens.numel() + if BOS_ID is None: + BOS_ID = 1 + self.bos_idx = ( + (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + ) + self.cursor = int(self.bos_idx[0]) + + def _submit_next_shard(self): + try: + path = next(self.file_iter) + return self._shard_pool.submit(load_data_shard, path) + except StopIteration: + return None + + def _advance_shard(self): + if self._next_shard is None: + self.file_iter = iter(self.files) + self._next_shard = self._shard_pool.submit( + load_data_shard, next(self.file_iter) + ) + self._init_shard(self._next_shard.result()) + self._next_shard = self._submit_next_shard() + + def _local_doc_starts(self, local_start, total_len): + lo = np.searchsorted(self.bos_idx, local_start, side="left") + hi = np.searchsorted(self.bos_idx, local_start + total_len, side="left") + return (self.bos_idx[lo:hi] - local_start).tolist() + + def _prepare_batch(self, num_tokens_local, max_seq_len): + per_rank_span = num_tokens_local + 1 + global_span = per_rank_span * self.world_size + while self.cursor + global_span > self.shard_size: + self._advance_shard() + local_start = self.cursor + self.rank * per_rank_span + buf = self.tokens[local_start : local_start + per_rank_span] + inputs = torch.empty(per_rank_span - 1, dtype=torch.int64, pin_memory=True) + targets = torch.empty(per_rank_span - 1, dtype=torch.int64, pin_memory=True) + inputs.copy_(buf[:-1]) + targets.copy_(buf[1:]) + starts = self._local_doc_starts(local_start, inputs.numel()) + cu_seqlens, max_seqlen = _build_cu_seqlens( + starts, inputs.numel(), inputs.device, max_seq_len, self.cu_bucket_size + ) + cu_seqlens = cu_seqlens.pin_memory() + self.cursor += global_span + return inputs, targets, cu_seqlens, max_seqlen + + def next_batch(self, global_tokens, grad_accum_steps): + num_tokens_local = global_tokens // (self.world_size * grad_accum_steps) + while len(self._prefetch_queue) < 2: + self._prefetch_queue.append( + self._batch_pool.submit(self._prepare_batch, num_tokens_local, self.max_seq_len)) + inputs, targets, cu_seqlens, max_seqlen = self._prefetch_queue.pop(0).result() + self._prefetch_queue.append( + self._batch_pool.submit(self._prepare_batch, num_tokens_local, self.max_seq_len)) + return ( + inputs[None].to(self.device, non_blocking=True), + targets[None].to(self.device, non_blocking=True), + cu_seqlens.to(self.device, non_blocking=True), + max_seqlen, + ) + + +class ShuffledSequenceLoader: + def __init__(self, h, device): + self.world_size = h.world_size + self.seq_len = h.train_seq_len + self.device = device + all_files = [Path(p) for p in sorted(glob.glob(h.train_files))] + if not all_files: + raise FileNotFoundError(f"No files found for pattern: {h.train_files}") + self.files = all_files[h.rank :: h.world_size] + self.rng = np.random.Generator(np.random.PCG64(h.rank)) + self.num_tokens = [_read_num_tokens(f) for f in self.files] + self.start_inds = [[] for _ in self.files] + for si in range(len(self.files)): + self._reset_shard(si) + + def _reset_shard(self, si): + max_phase = min( + self.seq_len - 1, max(0, self.num_tokens[si] - self.seq_len - 1) + ) + phase = int(self.rng.integers(max_phase + 1)) if max_phase > 0 else 0 + num_sequences = (self.num_tokens[si] - 1 - phase) // self.seq_len + sequence_order = self.rng.permutation(num_sequences) + self.start_inds[si] = (phase + sequence_order * self.seq_len).tolist() + + def next_batch(self, global_tokens, grad_accum_steps): + device_tokens = global_tokens // (self.world_size * grad_accum_steps) + device_batch_size = device_tokens // self.seq_len + remaining = np.array([len(s) for s in self.start_inds], dtype=np.float64) + x = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64) + y = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64) + for bi in range(device_batch_size): + total = remaining.sum() + if total <= 0: + for si in range(len(self.files)): + self._reset_shard(si) + remaining = np.array( + [len(s) for s in self.start_inds], dtype=np.float64 + ) + total = remaining.sum() + probs = remaining / total + si = int(self.rng.choice(len(self.files), p=probs)) + start_ind = self.start_inds[si].pop() + remaining[si] -= 1 + mm = _get_shard_memmap(self.files[si]) + window = torch.as_tensor( + np.array(mm[start_ind : start_ind + self.seq_len + 1], dtype=np.int64) + ) + x[bi] = window[:-1] + y[bi] = window[1:] + return x.to(self.device, non_blocking=True), y.to( + self.device, non_blocking=True + ) + + +class RMSNorm(nn.Module): + def __init__(self, eps=None): + super().__init__() + self.eps = eps + + def forward(self, x): + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x): + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +@triton.jit +def fused_log_softmax_dual_gather_kernel( + logits_ptr, + target_ids_ptr, + hint_ids_ptr, + log_p_y_out_ptr, + log_q_h_out_ptr, + n_rows, + n_cols, + block_cols: tl.constexpr, +): + row_idx = tl.program_id(0) + if row_idx >= n_rows: + return + target = tl.load(target_ids_ptr + row_idx) + hint = tl.load(hint_ids_ptr + row_idx) + row_offset = row_idx * n_cols + target_logit = tl.load(logits_ptr + row_offset + target).to(tl.float32) + hint_logit = tl.load(logits_ptr + row_offset + hint).to(tl.float32) + max_val = -float("inf") + for col_start in tl.range(0, n_cols, block_cols): + cols = col_start + tl.arange(0, block_cols) + mask = cols < n_cols + vals = tl.load( + logits_ptr + row_offset + cols, mask=mask, other=-float("inf") + ).to(tl.float32) + max_val = tl.maximum(max_val, tl.max(vals, axis=0)) + sum_exp = tl.zeros((), dtype=tl.float32) + for col_start in tl.range(0, n_cols, block_cols): + cols = col_start + tl.arange(0, block_cols) + mask = cols < n_cols + vals = tl.load( + logits_ptr + row_offset + cols, mask=mask, other=0.0 + ).to(tl.float32) + sum_exp += tl.sum(tl.where(mask, tl.exp(vals - max_val), 0.0), axis=0) + lse = max_val + tl.log(sum_exp) + tl.store(log_p_y_out_ptr + row_idx, target_logit - lse) + tl.store(log_q_h_out_ptr + row_idx, hint_logit - lse) + + +def fused_log_softmax_dual_gather(logits, target_ids, hint_ids): + bsz, seqlen, vocab = logits.shape + n_rows = bsz * seqlen + logits_flat = logits.reshape(n_rows, vocab).contiguous() + target_flat = target_ids.reshape(n_rows).contiguous() + hint_flat = hint_ids.reshape(n_rows).contiguous() + log_p_y_out = torch.empty(n_rows, dtype=torch.float32, device=logits.device) + log_q_h_out = torch.empty(n_rows, dtype=torch.float32, device=logits.device) + fused_log_softmax_dual_gather_kernel[(n_rows,)]( + logits_flat, + target_flat, + hint_flat, + log_p_y_out, + log_q_h_out, + n_rows, + vocab, + block_cols=1024, + num_warps=8, + ) + return log_p_y_out.reshape(bsz, seqlen), log_q_h_out.reshape(bsz, seqlen) + + +@triton.jit +def linear_leaky_relu_square_kernel( + a_desc, + b_desc, + c_desc, + aux_desc, + M, + N, + K, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + NUM_SMS: tl.constexpr, + FORWARD: tl.constexpr, +): + dtype = tl.bfloat16 + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + tile_id_c = start_pid - NUM_SMS + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for ki in range(k_tiles): + offs_k = ki * BLOCK_SIZE_K + a = a_desc.load([offs_am, offs_k]) + b = b_desc.load([offs_bn, offs_k]) + accumulator = tl.dot(a, b.T, accumulator) + tile_id_c += NUM_SMS + offs_am_c = offs_am + offs_bn_c = offs_bn + acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) + acc = tl.permute(acc, (0, 2, 1)) + acc0, acc1 = tl.split(acc) + c0 = acc0.to(dtype) + c1 = acc1.to(dtype) + if not FORWARD: + pre0 = aux_desc.load([offs_am_c, offs_bn_c]) + pre1 = aux_desc.load([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2]) + c0 = c0 * tl.where(pre0 > 0, 2.0 * pre0, 0.3 * pre0) + c1 = c1 * tl.where(pre1 > 0, 2.0 * pre1, 0.3 * pre1) + c_desc.store([offs_am_c, offs_bn_c], c0) + c_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], c1) + if FORWARD: + aux0 = tl.where(c0 > 0, c0, 0.3 * c0) + aux1 = tl.where(c1 > 0, c1, 0.3 * c1) + aux_desc.store([offs_am_c, offs_bn_c], aux0 * aux0) + aux_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], aux1 * aux1) + + +def linear_leaky_relu_square(a, b, aux=None): + M, K = a.shape + N, K2 = b.shape + assert K == K2 + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + forward = aux is None + if aux is None: + aux = torch.empty((M, N), device=a.device, dtype=a.dtype) + num_sms = torch.cuda.get_device_properties(a.device).multi_processor_count + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 256, 128, 64 + num_stages = 4 if forward else 3 + a_desc = TensorDescriptor.from_tensor(a, [BLOCK_SIZE_M, BLOCK_SIZE_K]) + b_desc = TensorDescriptor.from_tensor(b, [BLOCK_SIZE_N, BLOCK_SIZE_K]) + c_desc = TensorDescriptor.from_tensor(c, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + aux_desc = TensorDescriptor.from_tensor(aux, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + grid = lambda _meta: ( + min(num_sms, triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N)), + ) + linear_leaky_relu_square_kernel[grid]( + a_desc, + b_desc, + c_desc, + aux_desc, + M, + N, + K, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + NUM_SMS=num_sms, + FORWARD=forward, + num_stages=num_stages, + num_warps=8, + ) + if forward: + return c, aux + return c + + +class FusedLinearLeakyReLUSquareFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, w1, w2): + x_flat = x.reshape(-1, x.shape[-1]) + pre, post = linear_leaky_relu_square(x_flat, w1) + out = F.linear(post, w2) + ctx.save_for_backward(x, w1, w2, pre, post) + return out.view(*x.shape[:-1], out.shape[-1]) + + @staticmethod + def backward(ctx, grad_output): + x, w1, w2, pre, post = ctx.saved_tensors + x_flat = x.reshape(-1, x.shape[-1]) + grad_output_flat = grad_output.reshape(-1, grad_output.shape[-1]) + dw2 = grad_output_flat.T @ post + dpre = linear_leaky_relu_square(grad_output_flat, w2.T.contiguous(), aux=pre) + dw1 = dpre.T @ x_flat + dx = dpre @ w1 + return dx.view_as(x), dw1, dw2 + + +FusedLeakyReLUSquareMLP = FusedLinearLeakyReLUSquareFunction.apply + + +class Rotary(nn.Module): + def __init__(self, dim, base=1e4, train_seq_len=1024, rope_dims=0, yarn=True): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.yarn = yarn + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / base ** ( + torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + + def forward(self, seq_len, device, dtype): + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached < seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if self.yarn and seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * scale ** (rd / (rd - 2)) + inv_freq = 1.0 / new_base ** ( + torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd + ) + else: + inv_freq = self.inv_freq.float().to(device) + t = torch.arange(seq_len, device=device, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached[:, :seq_len].to(dtype=dtype), self._sin_cached[:, :seq_len].to(dtype=dtype) + + +def apply_rotary_emb(x, cos, sin, rope_dims=0): + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * -sin + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * -sin + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len, yarn=True, + attn_out_gate=False, attn_out_gate_src="proj", gate_window=12, + gated_attn=False, gated_attn_init_std=0.01, + sparse_attn_gate=False, sparse_attn_gate_init_std=0.0, sparse_attn_gate_scale=1.0, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + if int(attn_out_gate) + int(gated_attn) + int(sparse_attn_gate) > 1: + raise ValueError( + "attn_out_gate, gated_attn, and sparse_attn_gate are mutually exclusive" + ) + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + self.q_gain = nn.Parameter( + torch.full((num_heads,), qk_gain_init, dtype=torch.float32) + ) + self.rope_dims = 0 + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=train_seq_len, yarn=yarn) + self.use_xsa = False + # AttnOutGate (PR #1667 MarioPaerle): per-head multiplicative gate on attention + # output. CastedLinear so restore_fp32_params casts back to fp32 for GPTQ. + # _zero_init -> 2*sigmoid(0)=1 -> transparent at init. + self.attn_out_gate = attn_out_gate + self.attn_out_gate_src = attn_out_gate_src + self.gate_window = gate_window + if attn_out_gate: + self.attn_gate_proj = CastedLinear(gate_window, num_heads, bias=False) + self.attn_gate_proj._zero_init = True + # Gated Attention (arXiv:2505.06708, Qwen, NeurIPS 2025). Per-head sigmoid + # gate on SDPA output, BEFORE out_proj. Gate projection W_g: (num_heads, dim). + # Name "attn_gate_w" contains "attn_gate" substring so it matches + # CONTROL_TENSOR_NAME_PATTERNS and routes to the scalar AdamW group. + # fp32 Parameter -> restore_fp32_params path covers it via the ndim<2 OR + # name-pattern check (name matches "attn_gate"). Cast to x.dtype on use. + self.gated_attn = gated_attn + if gated_attn: + W = torch.empty(num_heads, dim, dtype=torch.float32) + nn.init.normal_(W, mean=0.0, std=gated_attn_init_std) + self.attn_gate_w = nn.Parameter(W) + # Sparse attention head-output gate (modded-nanogpt style). Keeps dense SDPA + # and only narrows the gate input to the first gate_window residual dims. + # W_g: (num_heads, gate_window). y_{t,h} <- sigmoid(scale * W_g_h @ x_t[:gate_window]) * y_{t,h}. + # Shares attn_gate_w name with dense GatedAttn so the quant routing + # (CONTROL_TENSOR_NAME_PATTERNS / attn_gate_w int8 passthrough) is unchanged. + self.sparse_attn_gate = sparse_attn_gate + self.sparse_attn_gate_scale = sparse_attn_gate_scale + if sparse_attn_gate: + W = torch.empty(num_heads, gate_window, dtype=torch.float32) + if sparse_attn_gate_init_std > 0: + nn.init.normal_(W, mean=0.0, std=sparse_attn_gate_init_std) + else: + nn.init.zeros_(W) + self.attn_gate_w = nn.Parameter(W) + + def _xsa_efficient(self, y, v): + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x, q_w, k_w, v_w, out_w, cu_seqlens=None, max_seqlen=0): + bsz, seqlen, dim = x.shape + # q_raw kept around as a tap point for attn_out_gate_src='q' (post-projection, + # pre-reshape, pre-RoPE). + q_raw = F.linear(x, q_w.to(x.dtype)) + q = q_raw.reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = F.linear(x, k_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = F.linear(x, v_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + if cu_seqlens is not None: + y = flash_attn_varlen_func( + q[0], + k[0], + v[0], + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + causal=True, + window_size=(-1, -1), + )[None] + else: + y = flash_attn_3_func(q, k, v, causal=True) + if self.use_xsa: + y = self._xsa_efficient(y, v) + # AttnOutGate inlined (PR #1667). Inline + .contiguous() barrier so torch.compile + # fullgraph=True is happy (this avoids the @torch.compiler.disable trap that + # crashed gates v3). Per-head gate on (B,T,H,D) tensor: g shape [B,T,H], broadcast + # over D via [..., None]. zero-init weight -> 2*sigmoid(0)=1 -> transparent. + if self.attn_out_gate: + gate_src = q_raw if self.attn_out_gate_src == "q" else x + gate_in = gate_src[..., : self.gate_window].contiguous() + g = 2.0 * torch.sigmoid(self.attn_gate_proj(gate_in)) + y = y * g[..., None] + # Gated Attention (arXiv:2505.06708 G1). Inline + .contiguous() barrier so + # torch.compile fullgraph=True is happy. Per-head gate on (B,T,H,D): g shape + # [B,T,H], broadcast over D via [..., None]. Paper: g = sigmoid(x @ W_g.T) + # where W_g: (H, dim). .to(x.dtype) on fp32 param before broadcast with bf16. + if self.gated_attn: + x_c = x.contiguous() + g = torch.sigmoid(F.linear(x_c, self.attn_gate_w.to(x.dtype))) + y = y * g[..., None] + # Sparse head-output gate: narrower (gate_window) input, same shape g as GatedAttn. + if self.sparse_attn_gate: + gate_in = x[..., : self.gate_window].contiguous() + g = torch.sigmoid( + self.sparse_attn_gate_scale + * F.linear(gate_in, self.attn_gate_w.to(x.dtype)) + ) + y = y * g[..., None] + y = y.reshape(bsz, seqlen, dim) + self._last_proj_input = y.detach() if getattr(self, "_calib", False) else None + return F.linear(y, out_w.to(x.dtype)) + + +class MLP(nn.Module): + def __init__(self, dim, mlp_mult): + super().__init__() + self.use_fused = True + + def forward(self, x, up_w, down_w): + if self.training and self.use_fused: + return FusedLeakyReLUSquareMLP(x, up_w.to(x.dtype), down_w.to(x.dtype)) + hidden = F.leaky_relu(F.linear(x, up_w.to(x.dtype)), negative_slope=0.3).square() + self._last_down_input = hidden.detach() if getattr(self, "_calib", False) else None + return F.linear(hidden, down_w.to(x.dtype)) + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + train_seq_len, + layer_idx=0, + ln_scale=False, + yarn=True, + attn_out_gate=False, + attn_out_gate_src="proj", + gate_window=12, + gated_attn=False, + gated_attn_init_std=0.01, + sparse_attn_gate=False, + sparse_attn_gate_init_std=0.0, + sparse_attn_gate_scale=1.0, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention( + dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len, yarn=yarn, + attn_out_gate=attn_out_gate, attn_out_gate_src=attn_out_gate_src, gate_window=gate_window, + gated_attn=gated_attn, gated_attn_init_std=gated_attn_init_std, + sparse_attn_gate=sparse_attn_gate, + sparse_attn_gate_init_std=sparse_attn_gate_init_std, + sparse_attn_gate_scale=sparse_attn_gate_scale, + ) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter( + torch.stack((torch.ones(dim), torch.zeros(dim))).float() + ) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + + def forward(self, x, x0, q_w, k_w, v_w, out_w, up_w, down_w, cu_seqlens=None, max_seqlen=0): + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn( + self.attn_norm(x_in) * self.ln_scale_factor, + q_w, k_w, v_w, out_w, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[ + None, None, : + ] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor, up_w, down_w) + return x_out + +class GPT(nn.Module): + def __init__(self, h): + super().__init__() + if h.logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {h.logit_softcap}") + self.tie_embeddings = h.tie_embeddings + self.tied_embed_init_std = h.tied_embed_init_std + self.logit_softcap = h.logit_softcap + self.fused_ce_enabled = bool(h.fused_ce_enabled) + self.tok_emb = nn.Embedding(h.vocab_size, h.model_dim) + self.num_layers = h.num_layers + head_dim = h.model_dim // h.num_heads + kv_dim = h.num_kv_heads * head_dim + hidden_dim = int(h.mlp_mult * h.model_dim) + self.qo_bank = nn.Parameter(torch.empty(2 * h.num_layers, h.model_dim, h.model_dim)) + self.kv_bank = nn.Parameter(torch.empty(2 * h.num_layers, kv_dim, h.model_dim)) + self.mlp_up_bank = nn.Parameter(torch.empty(h.num_layers, hidden_dim, h.model_dim)) + self.mlp_down_bank = nn.Parameter(torch.empty(h.num_layers, h.model_dim, hidden_dim)) + self.num_encoder_layers = h.num_layers // 2 + self.num_decoder_layers = h.num_layers - self.num_encoder_layers + self.blocks = nn.ModuleList( + [ + Block( + h.model_dim, + h.num_heads, + h.num_kv_heads, + h.mlp_mult, + h.rope_base, + h.qk_gain_init, + h.train_seq_len, + layer_idx=i, + ln_scale=h.ln_scale, + yarn=h.rope_yarn, + attn_out_gate=h.attn_out_gate_enabled, + attn_out_gate_src=h.attn_out_gate_src, + gate_window=h.gate_window, + gated_attn=h.gated_attn_enabled, + gated_attn_init_std=h.gated_attn_init_std, + sparse_attn_gate=h.sparse_attn_gate_enabled, + sparse_attn_gate_init_std=h.sparse_attn_gate_init_std, + sparse_attn_gate_scale=h.sparse_attn_gate_scale, + ) + for i in range(h.num_layers) + ] + ) + if h.rope_dims > 0: + head_dim = h.model_dim // h.num_heads + for block in self.blocks: + block.attn.rope_dims = h.rope_dims + block.attn.rotary = Rotary( + head_dim, + base=h.rope_base, + train_seq_len=h.train_seq_len, + rope_dims=h.rope_dims, + yarn=h.rope_yarn, + ) + self.final_norm = RMSNorm() + self.lm_head = ( + None + if h.tie_embeddings + else CastedLinear(h.model_dim, h.vocab_size, bias=False) + ) + if self.lm_head is not None: + self.lm_head._zero_init = True + if h.xsa_last_n > 0: + for i in range(max(0, h.num_layers - h.xsa_last_n), h.num_layers): + self.blocks[i].attn.use_xsa = True + self.looping_active = False + if h.num_loops > 0: + loop_seg = list(range(h.loop_start, h.loop_end + 1)) + all_indices = list(range(h.loop_start)) + for _ in range(h.num_loops + 1): + all_indices.extend(loop_seg) + all_indices.extend(range(h.loop_end + 1, h.num_layers)) + num_enc = len(all_indices) // 2 + self.encoder_indices = all_indices[:num_enc] + self.decoder_indices = all_indices[num_enc:] + else: + self.encoder_indices = list(range(self.num_encoder_layers)) + self.decoder_indices = list(range(self.num_encoder_layers, h.num_layers)) + self.num_skip_weights = min( + len(self.encoder_indices), len(self.decoder_indices) + ) + self.skip_weights = nn.Parameter( + torch.ones(self.num_skip_weights, h.model_dim, dtype=torch.float32) + ) + self.skip_gates = ( + nn.Parameter( + torch.zeros(self.num_skip_weights, h.model_dim, dtype=torch.float32) + ) + if h.skip_gates_enabled + else None + ) + self.parallel_start_layer = h.parallel_start_layer + self.parallel_final_lane = h.parallel_final_lane.lower() + self.parallel_post_lambdas = nn.Parameter( + torch.ones(h.num_layers, 2, 2, dtype=torch.float32) + ) + self.parallel_resid_lambdas = nn.Parameter( + torch.full((h.num_layers, 2), 1.1, dtype=torch.float32) + ) + # SmearGate (PR #1667 / modded-nanogpt @classiclarryd): + # x_t <- x_t + lam * sigmoid(W * x_t[:gate_window]) * x_{t-1}. + # Per-token forward-1 smear of the embedding lane. W zero-init + lam=0 -> + # transparent at init. Uses CastedLinear so restore_fp32_params handles dtype. + self.smear_gate_enabled = h.smear_gate_enabled + if self.smear_gate_enabled: + self.smear_window = h.gate_window + self.smear_gate = CastedLinear(self.smear_window, 1, bias=False) + self.smear_gate._zero_init = True + self.smear_lambda = nn.Parameter(torch.zeros(1, dtype=torch.float32)) + # V19: Asymmetric Logit Rescale (PR #1923 jorge-asenjo). + # Two learnable softcap scales applied on the EVAL path (forward_logits + + # forward_ttt). Init to logit_softcap so the layer is identity at step 0. + # Train path keeps the single fused softcap to preserve PR #1855 numerics. + self.asym_logit_enabled = bool(int(os.environ.get("ASYM_LOGIT_RESCALE", "1"))) + if self.asym_logit_enabled: + self.softcap_pos = nn.Parameter(torch.tensor(float(h.logit_softcap), dtype=torch.float32)) + self.softcap_neg = nn.Parameter(torch.tensor(float(h.logit_softcap), dtype=torch.float32)) + self.jepa_horizon = h.jepa_lite_horizon + self.jepa_stride = h.jepa_lite_stride + self.register_buffer( + "jepa_weight_current", + torch.zeros((), dtype=torch.float32), + persistent=False, + ) + self.jepa_predictor = ( + nn.Sequential( + CastedLinear(h.model_dim, h.jepa_lite_hidden, bias=False), + nn.LeakyReLU(0.5), + CastedLinear(h.jepa_lite_hidden, h.model_dim, bias=False), + ) + if h.jepa_lite_enabled + else None + ) + self.bijepa_enabled = bool(h.bijepa_enabled) + self.bijepa_backward_weight = float(h.bijepa_backward_weight) + self.jepa_bwd_predictor = ( + nn.Sequential( + CastedLinear(h.model_dim, h.jepa_lite_hidden, bias=False), + nn.LeakyReLU(0.5), + CastedLinear(h.jepa_lite_hidden, h.model_dim, bias=False), + ) + if h.jepa_lite_enabled and h.bijepa_enabled + else None + ) + self._init_weights() + + def _init_weights(self): + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + n = self.num_layers + proj_scale = 1.0 / math.sqrt(2 * n) + for i in range(n): + nn.init.orthogonal_(self.qo_bank.data[i], gain=1.0) + nn.init.zeros_(self.qo_bank.data[n + i]) + self.qo_bank.data[n + i].mul_(proj_scale) + nn.init.orthogonal_(self.kv_bank.data[i], gain=1.0) + nn.init.orthogonal_(self.kv_bank.data[n + i], gain=1.0) + for i in range(n): + nn.init.orthogonal_(self.mlp_up_bank.data[i], gain=1.0) + nn.init.zeros_(self.mlp_down_bank.data[i]) + self.mlp_down_bank.data[i].mul_(proj_scale) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif ( + module.weight.ndim == 2 + and module.weight.shape[0] >= 64 + and module.weight.shape[1] >= 64 + ): + nn.init.orthogonal_(module.weight, gain=1.0) + + def _bank_weights(self, i): + n = self.num_layers + return ( + self.qo_bank[i], + self.kv_bank[i], + self.kv_bank[n + i], + self.qo_bank[n + i], + self.mlp_up_bank[i], + self.mlp_down_bank[i], + ) + + def _parallel_block( + self, block_idx, lane0, lane1, x0, + q_w, k_w, v_w, out_w, up_w, down_w, + cu_seqlens=None, max_seqlen=0, + ): + block = self.blocks[block_idx] + mix = block.resid_mix.to(dtype=lane0.dtype) + attn_read = mix[0][None, None, :] * lane0 + mix[1][None, None, :] * x0 + attn_out = block.attn( + block.attn_norm(attn_read) * block.ln_scale_factor, + q_w, k_w, v_w, out_w, + cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, + ) + attn_out = block.attn_scale.to(dtype=attn_out.dtype)[None, None, :] * attn_out + mlp_read = lane1 + mlp_out = block.mlp_scale.to(dtype=lane1.dtype)[None, None, :] * block.mlp( + block.mlp_norm(mlp_read) * block.ln_scale_factor, up_w, down_w + ) + attn_resid = self.parallel_resid_lambdas[block_idx, 0].to(dtype=lane0.dtype) + attn_post = self.parallel_post_lambdas[block_idx, 0].to(dtype=lane0.dtype) + mlp_resid = self.parallel_resid_lambdas[block_idx, 1].to(dtype=lane0.dtype) + mlp_post = self.parallel_post_lambdas[block_idx, 1].to(dtype=lane0.dtype) + lane0 = attn_resid * lane0 + attn_post[0] * attn_out + mlp_post[0] * mlp_out + lane1 = mlp_resid * lane1 + attn_post[1] * attn_out + mlp_post[1] * mlp_out + return lane0, lane1 + + def _final_parallel_hidden(self, lane0, lane1): + if self.parallel_final_lane == "mlp": + return lane1 + if self.parallel_final_lane == "attn": + return lane0 + return 0.5 * (lane0 + lane1) + + def _forward_hidden(self, input_ids, cu_seqlens=None, max_seqlen=0): + """Run the encoder/decoder stack to the final RMSNorm; returns pre-projection hidden. + Shared by eval (softcap+projection via forward_logits) and train (fused CE path).""" + x = self.tok_emb(input_ids) + # SmearGate (PR #1667). lam=0 + W=0 -> identity at init. + # Cross-doc leak fix: zero the prev-token smear at any position whose current token + # is BOS, so the BOS embedding starting doc N+1 in a packed stream is not + # contaminated by doc N's last token (audited issue on PR#1797 base). + if self.smear_gate_enabled: + sl = self.smear_lambda.to(dtype=x.dtype) + gate_in = x[:, 1:, : self.smear_window].contiguous() + g = sl * torch.sigmoid(self.smear_gate(gate_in)) + not_bos = (input_ids[:, 1:] != BOS_ID).to(x.dtype).unsqueeze(-1) + x = torch.cat([x[:, :1], x[:, 1:] + g * x[:, :-1] * not_bos], dim=1) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips = [] + enc_iter = ( + self.encoder_indices + if self.looping_active + else range(self.num_encoder_layers) + ) + dec_iter = ( + self.decoder_indices + if self.looping_active + else range( + self.num_encoder_layers, + self.num_encoder_layers + self.num_decoder_layers, + ) + ) + for i in enc_iter: + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + x = self.blocks[i](x, x0, q_w, k_w, v_w, out_w, up_w, down_w, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + skips.append(x) + psl = self.parallel_start_layer + lane0 = None + lane1 = None + for skip_idx, i in enumerate(dec_iter): + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + if i >= psl and psl > 0: + if lane0 is None: + lane0 = x + lane1 = x + if skip_idx < self.num_skip_weights and skips: + skip = skips.pop() + w = self.skip_weights[skip_idx].to(dtype=lane0.dtype)[None, None, :] + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=lane0.dtype))[None, None, :] + lane0 = torch.lerp(w * skip, lane0, g) + else: + lane0 = lane0 + w * skip + lane0, lane1 = self._parallel_block( + i, lane0, lane1, x0, q_w, k_w, v_w, out_w, up_w, down_w, + cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, + ) + else: + if skip_idx < self.num_skip_weights and skips: + scaled_skip = ( + self.skip_weights[skip_idx].to(dtype=x.dtype)[None, None, :] + * skips.pop() + ) + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=x.dtype))[None, None, :] + x = torch.lerp(scaled_skip, x, g) + else: + x = x + scaled_skip + x = self.blocks[i](x, x0, q_w, k_w, v_w, out_w, up_w, down_w, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + if lane0 is not None: + x = self._final_parallel_hidden(lane0, lane1) + x = self.final_norm(x) + return x + + def _project_logits(self, hidden): + if self.tie_embeddings: + return F.linear(hidden, self.tok_emb.weight) + return self.lm_head(hidden) + + def _apply_asym_softcap(self, logits): + # V19: Asymmetric softcap (PR #1923). Splits the logit_softcap scalar into + # learnable positive/negative branches. Score-first preserved: still a + # bounded, normalized post-projection nonlinearity feeding a standard + # softmax over the full vocab. + sp = self.softcap_pos.to(logits.dtype) + sn = self.softcap_neg.to(logits.dtype) + return torch.where(logits > 0, sp * torch.tanh(logits / sp), sn * torch.tanh(logits / sn)) + + def forward_logits(self, input_ids, cu_seqlens=None, max_seqlen=0): + hidden = self._forward_hidden(input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + logits_proj = self._project_logits(hidden) + if self.asym_logit_enabled: + return self._apply_asym_softcap(logits_proj) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + def jepa_aux_loss(self, features): + if self.jepa_predictor is None: + return features.new_zeros(()) + horizon = max(1, int(self.jepa_horizon)) + stride = max(1, int(self.jepa_stride)) + if features.size(1) <= horizon: + return features.new_zeros(()) + src = features[:, :-horizon:stride, :] + tgt = features[:, horizon::stride, :] + n = min(src.size(1), tgt.size(1)) + if n <= 0: + return features.new_zeros(()) + src = F.normalize(src[:, :n, :].float(), dim=-1) + tgt = F.normalize(tgt[:, :n, :].detach().float(), dim=-1) + pred = F.normalize(self.jepa_predictor(src), dim=-1) + loss = (1.0 - (pred * tgt).sum(dim=-1)).mean() + if self.jepa_bwd_predictor is not None: + bwd_src = F.normalize(features[:, horizon::stride, :][:, :n, :].float(), dim=-1) + bwd_tgt = F.normalize( + features[:, :-horizon:stride, :][:, :n, :].detach().float(), + dim=-1, + ) + bwd_pred = F.normalize(self.jepa_bwd_predictor(bwd_src), dim=-1) + bwd_loss = (1.0 - (bwd_pred * bwd_tgt).sum(dim=-1)).mean() + loss = loss + self.bijepa_backward_weight * bwd_loss + return loss * self.jepa_weight_current.to( + device=features.device, + dtype=features.dtype, + ) + + def forward(self, input_ids, target_ids, cu_seqlens=None, max_seqlen=0): + hidden = self._forward_hidden(input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + logits_proj = self._project_logits(hidden) + flat_targets = target_ids.reshape(-1) + # Fused softcapped-CE kernel (training path only). Applies softcap inside the + # Triton kernel; takes pre-softcap logits_proj. Non-fused path matches stock + # PR-1736 numerics exactly (softcap in fp32, then F.cross_entropy on fp32). + if self.fused_ce_enabled: + ce_loss = softcapped_cross_entropy( + logits_proj.reshape(-1, logits_proj.size(-1)), + flat_targets, + self.logit_softcap, + reduction="mean", + ) + else: + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + ce_loss = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + flat_targets, + reduction="mean", + ) + if self.jepa_predictor is not None: + ce_loss = ce_loss + self.jepa_aux_loss(hidden) + return ce_loss + + def forward_with_hidden(self, input_ids, target_ids, cu_seqlens=None, max_seqlen=0): + hidden = self._forward_hidden( + input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ) + logits_proj = self._project_logits(hidden) + flat_targets = target_ids.reshape(-1) + if self.fused_ce_enabled: + ce_loss = softcapped_cross_entropy( + logits_proj.reshape(-1, logits_proj.size(-1)), + flat_targets, + self.logit_softcap, + reduction="mean", + ) + else: + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + ce_loss = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + flat_targets, + reduction="mean", + ) + if self.jepa_predictor is not None: + ce_loss = ce_loss + self.jepa_aux_loss(hidden) + return ce_loss, hidden + + def forward_ttt(self, input_ids, target_ids, lora, hint_ids=None): + x = self.tok_emb(input_ids) + # SmearGate on the TTT path — same inline compute as forward_logits. + # Cross-doc leak fix: see _forward_hidden comment. + if self.smear_gate_enabled: + sl = self.smear_lambda.to(dtype=x.dtype) + gate_in = x[:, 1:, : self.smear_window].contiguous() + g = sl * torch.sigmoid(self.smear_gate(gate_in)) + not_bos = (input_ids[:, 1:] != BOS_ID).to(x.dtype).unsqueeze(-1) + x = torch.cat([x[:, :1], x[:, 1:] + g * x[:, :-1] * not_bos], dim=1) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips = [] + enc_iter = ( + self.encoder_indices + if self.looping_active + else list(range(self.num_encoder_layers)) + ) + dec_iter = ( + self.decoder_indices + if self.looping_active + else list( + range( + self.num_encoder_layers, + self.num_encoder_layers + self.num_decoder_layers, + ) + ) + ) + slot = 0 + for i in enc_iter: + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + x = self._block_with_lora(self.blocks[i], x, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w) + slot += 1 + skips.append(x) + psl = self.parallel_start_layer + lane0 = None + lane1 = None + for skip_idx, i in enumerate(dec_iter): + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + if i >= psl and psl > 0: + if lane0 is None: + lane0 = x + lane1 = x + if skip_idx < self.num_skip_weights and skips: + skip = skips.pop() + w = self.skip_weights[skip_idx].to(dtype=lane0.dtype)[None, None, :] + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=lane0.dtype))[None, None, :] + lane0 = torch.lerp(w * skip, lane0, g) + else: + lane0 = lane0 + w * skip + lane0, lane1 = self._parallel_block_with_lora( + i, lane0, lane1, x0, lora, slot, + q_w, k_w, v_w, out_w, up_w, down_w, + ) + else: + if skip_idx < self.num_skip_weights and skips: + scaled_skip = ( + self.skip_weights[skip_idx].to(dtype=x.dtype)[None, None, :] + * skips.pop() + ) + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=x.dtype))[None, None, :] + x = torch.lerp(scaled_skip, x, g) + else: + x = x + scaled_skip + x = self._block_with_lora(self.blocks[i], x, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w) + slot += 1 + if lane0 is not None: + x = self._final_parallel_hidden(lane0, lane1) + x = self.final_norm(x) + if self.tie_embeddings: + logits = F.linear(x, self.tok_emb.weight) + else: + logits = self.lm_head(x) + logits = logits + lora.lm_head_lora(x) + # V19: same asymmetric softcap on the TTT eval path. + if self.asym_logit_enabled: + logits = self._apply_asym_softcap(logits) + else: + logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) + bsz, sl, V = logits.shape + if hint_ids is None: + return F.cross_entropy( + logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none" + ).reshape(bsz, sl) + if not logits.requires_grad: + log_p_y, log_q_h = fused_log_softmax_dual_gather( + logits, target_ids, hint_ids.clamp(min=0) + ) + return -log_p_y, log_q_h + ls = F.log_softmax(logits.float(), dim=-1) + log_p_y = ls.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1) + log_q_h = ls.gather(-1, hint_ids.clamp(min=0).unsqueeze(-1)).squeeze(-1) + return -log_p_y, log_q_h + + def _block_with_lora(self, block, x, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w): + mix = block.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + n = block.attn_norm(x_in) * block.ln_scale_factor + attn = block.attn + bsz, seqlen, dim = n.shape + # Keep raw Q for AttnOutGate src='q' (matches forward path semantics). + q_raw = F.linear(n, q_w.to(n.dtype)) + if lora.q_loras is not None: + q_raw = q_raw + lora.q_loras[slot](n) + q = q_raw.reshape(bsz, seqlen, attn.num_heads, attn.head_dim) + k = F.linear(n, k_w.to(n.dtype)) + if lora.k_loras is not None: + k = k + lora.k_loras[slot](n) + k = k.reshape(bsz, seqlen, attn.num_kv_heads, attn.head_dim) + v = F.linear(n, v_w.to(n.dtype)) + if lora.v_loras is not None: + v = v + lora.v_loras[slot](n) + v = v.reshape(bsz, seqlen, attn.num_kv_heads, attn.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = attn.rotary(seqlen, n.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, attn.rope_dims) + k = apply_rotary_emb(k, cos, sin, attn.rope_dims) + q = q * attn.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if attn.use_xsa: + y = attn._xsa_efficient(y, v) + # AttnOutGate (TTT path) — inline + .contiguous() barrier, same as the eval path. + if attn.attn_out_gate: + gate_src = q_raw if attn.attn_out_gate_src == "q" else n + gate_in = gate_src[..., : attn.gate_window].contiguous() + g = 2.0 * torch.sigmoid(attn.attn_gate_proj(gate_in)) + y = y * g[..., None] + # Gated Attention (TTT path). Gate input is n (post-norm block input), same + # as eval path. .to(n.dtype) on fp32 param before bf16 broadcast. + if attn.gated_attn: + n_c = n.contiguous() + g = torch.sigmoid(F.linear(n_c, attn.attn_gate_w.to(n.dtype))) + y = y * g[..., None] + # Sparse attention head-output gate (TTT path) — must match the eval path in + # forward() exactly, else training (which applied the gate) and TTT eval (which + # skipped it) produce mismatched representations and catastrophic BPB regression. + if attn.sparse_attn_gate: + gate_in = n[..., : attn.gate_window].contiguous() + g = torch.sigmoid( + attn.sparse_attn_gate_scale + * F.linear(gate_in, attn.attn_gate_w.to(n.dtype)) + ) + y = y * g[..., None] + y = y.reshape(bsz, seqlen, dim) + attn_out = F.linear(y, out_w.to(n.dtype)) + if lora.o_loras is not None: + attn_out = attn_out + lora.o_loras[slot](n) + x_out = x_in + block.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + mlp_n = block.mlp_norm(x_out) * block.ln_scale_factor + mlp_out = block.mlp(mlp_n, up_w, down_w) + if lora.mlp_loras is not None: + mlp_out = mlp_out + lora.mlp_loras[slot](mlp_n) + x_out = x_out + block.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * mlp_out + return x_out + + def _parallel_block_with_lora( + self, block_idx, lane0, lane1, x0, lora, slot, + q_w, k_w, v_w, out_w, up_w, down_w, + ): + block = self.blocks[block_idx] + mix = block.resid_mix.to(dtype=lane0.dtype) + attn_read = mix[0][None, None, :] * lane0 + mix[1][None, None, :] * x0 + n = block.attn_norm(attn_read) * block.ln_scale_factor + attn = block.attn + bsz, seqlen, dim = n.shape + q_raw = F.linear(n, q_w.to(n.dtype)) + if lora.q_loras is not None: + q_raw = q_raw + lora.q_loras[slot](n) + q = q_raw.reshape(bsz, seqlen, attn.num_heads, attn.head_dim) + k = F.linear(n, k_w.to(n.dtype)) + if lora.k_loras is not None: + k = k + lora.k_loras[slot](n) + k = k.reshape(bsz, seqlen, attn.num_kv_heads, attn.head_dim) + v = F.linear(n, v_w.to(n.dtype)) + if lora.v_loras is not None: + v = v + lora.v_loras[slot](n) + v = v.reshape(bsz, seqlen, attn.num_kv_heads, attn.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = attn.rotary(seqlen, n.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, attn.rope_dims) + k = apply_rotary_emb(k, cos, sin, attn.rope_dims) + q = q * attn.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if attn.use_xsa: + y = attn._xsa_efficient(y, v) + # AttnOutGate (TTT parallel path) — inline + .contiguous() barrier. + if attn.attn_out_gate: + gate_src = q_raw if attn.attn_out_gate_src == "q" else n + gate_in = gate_src[..., : attn.gate_window].contiguous() + g = 2.0 * torch.sigmoid(attn.attn_gate_proj(gate_in)) + y = y * g[..., None] + # Gated Attention (TTT parallel path). Gate input is n (post-norm block input). + if attn.gated_attn: + n_c = n.contiguous() + g = torch.sigmoid(F.linear(n_c, attn.attn_gate_w.to(n.dtype))) + y = y * g[..., None] + # Sparse attention head-output gate (TTT parallel path) — must match the + # eval path in forward() to keep train/eval semantics in sync. + if attn.sparse_attn_gate: + gate_in = n[..., : attn.gate_window].contiguous() + g = torch.sigmoid( + attn.sparse_attn_gate_scale + * F.linear(gate_in, attn.attn_gate_w.to(n.dtype)) + ) + y = y * g[..., None] + y = y.reshape(bsz, seqlen, dim) + attn_out = F.linear(y, out_w.to(n.dtype)) + if lora.o_loras is not None: + attn_out = attn_out + lora.o_loras[slot](n) + attn_out = block.attn_scale.to(dtype=attn_out.dtype)[None, None, :] * attn_out + mlp_read = lane1 + mlp_n = block.mlp_norm(mlp_read) * block.ln_scale_factor + mlp_out = block.mlp(mlp_n, up_w, down_w) + if lora.mlp_loras is not None: + mlp_out = mlp_out + lora.mlp_loras[slot](mlp_n) + mlp_out = block.mlp_scale.to(dtype=lane1.dtype)[None, None, :] * mlp_out + attn_resid = self.parallel_resid_lambdas[block_idx, 0].to(dtype=lane0.dtype) + attn_post = self.parallel_post_lambdas[block_idx, 0].to(dtype=lane0.dtype) + mlp_resid = self.parallel_resid_lambdas[block_idx, 1].to(dtype=lane0.dtype) + mlp_post = self.parallel_post_lambdas[block_idx, 1].to(dtype=lane0.dtype) + lane0 = attn_resid * lane0 + attn_post[0] * attn_out + mlp_post[0] * mlp_out + lane1 = mlp_resid * lane1 + attn_post[1] * attn_out + mlp_post[1] * mlp_out + return lane0, lane1 + + +class BatchedLinearLoRA(nn.Module): + # PR-1767: rank-scaled output (alpha/rank), like standard LoRA. Decouples + # effective magnitude from rank so changing rank does not change LR scale. + _ALPHA = float(os.environ.get("TTT_LORA_ALPHA", "144")) + # PR-1767: optionally keep A warm across per-doc resets (only B is zeroed). + # Accumulates useful feature directions across documents within a TTT phase. + _WARM_START_A = bool(int(os.environ.get("TTT_WARM_START_A", "1"))) + + def __init__(self, bsz, in_features, out_features, rank): + super().__init__() + self._bound = 1.0 / math.sqrt(in_features) + self._scale = self._ALPHA / rank + self.A = nn.Parameter( + torch.empty(bsz, rank, in_features).uniform_(-self._bound, self._bound) + ) + self.B = nn.Parameter(torch.zeros(bsz, out_features, rank)) + + def reset(self): + with torch.no_grad(): + if not self._WARM_START_A: + self.A.uniform_(-self._bound, self._bound) + self.B.zero_() + + def forward(self, x): + return ((x @ self.A.transpose(1, 2)) @ self.B.transpose(1, 2)) * self._scale + + +class BatchedTTTLoRA(nn.Module): + def __init__( + self, bsz, model, rank, + q_lora=True, k_lora=True, v_lora=True, mlp_lora=True, o_lora=True, + ): + super().__init__() + self.bsz = bsz + dim = model.qo_bank.shape[-1] + vocab = model.tok_emb.num_embeddings + if getattr(model, "looping_active", False): + num_slots = len(model.encoder_indices) + len(model.decoder_indices) + else: + num_slots = len(model.blocks) + kv_dim = model.blocks[0].attn.num_kv_heads * ( + dim // model.blocks[0].attn.num_heads + ) + embed_dim = model.tok_emb.embedding_dim + self.lm_head_lora = BatchedLinearLoRA(bsz, embed_dim, vocab, rank) + self.q_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_slots)] + ) + if q_lora + else None + ) + self.v_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, kv_dim, rank) for _ in range(num_slots)] + ) + if v_lora + else None + ) + self.k_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, kv_dim, rank) for _ in range(num_slots)] + ) + if k_lora + else None + ) + self.mlp_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_slots)] + ) + if mlp_lora + else None + ) + self.o_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_slots)] + ) + if o_lora + else None + ) + + def reset(self): + with torch.no_grad(): + self.lm_head_lora.reset() + for loras in [self.q_loras, self.v_loras, self.k_loras, + self.mlp_loras, self.o_loras]: + if loras is not None: + for lora in loras: + lora.reset() + + +# Polar Express per-iteration minimax Newton-Schulz coefficients (PR #1344). +# Replaces the fixed (3.4445, -4.775, 2.0315) coefficients of stock Muon. +# Applied at backend_steps=5 — taking more than 5 iterations from this list +# falls back to the final (converged) tuple via the slice guard below. +_PE_COEFFS = ( + (8.156554524902461, -22.48329292557795, 15.878769915207462), + (4.042929935166739, -2.808917465908714, 0.5000178451051316), + (3.8916678022926607, -2.772484153217685, 0.5060648178503393), + (3.285753657755655, -2.3681294933425376, 0.46449024233003106), + (2.3465413258596377, -1.7097828382687081, 0.42323551169305323), +) + + +@torch.compile +def zeropower_via_newtonschulz5(G, steps=10, eps=1e-07): + was_2d = G.ndim == 2 + if was_2d: + G = G.unsqueeze(0) + X = G.bfloat16() + transposed = X.size(-2) > X.size(-1) + if transposed: + X = X.mT + X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps) + coeffs = _PE_COEFFS[:steps] if steps <= len(_PE_COEFFS) else _PE_COEFFS + for a, b, c in coeffs: + A = X @ X.mT + B = b * A + c * (A @ A) + X = a * X + B @ X + if transposed: + X = X.mT + if was_2d: + X = X.squeeze(0) + return X + + +class Muon(torch.optim.Optimizer): + def __init__( + self, + params, + lr, + momentum, + backend_steps, + nesterov=True, + weight_decay=0.0, + row_normalize=False, + ): + super().__init__( + params, + dict( + lr=lr, + momentum=momentum, + backend_steps=backend_steps, + nesterov=nesterov, + weight_decay=weight_decay, + row_normalize=row_normalize, + ), + ) + self._built = False + + def _build(self): + self._distributed = dist.is_available() and dist.is_initialized() + self._world_size = dist.get_world_size() if self._distributed else 1 + self._rank = dist.get_rank() if self._distributed else 0 + ws = self._world_size + self._bank_meta = [] + for group in self.param_groups: + for p in group["params"]: + B = p.shape[0] + padded_B = ((B + ws - 1) // ws) * ws + shard_B = padded_B // ws + tail = p.shape[1:] + dev = p.device + self._bank_meta.append({ + "p": p, + "B": B, + "padded_grad": torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + "shard": torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + "shard_mom": torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + "full_update": torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + "scale": max(1, p.shape[-2] / p.shape[-1]) ** 0.5, + }) + self._bank_meta.sort(key=lambda m: -m["p"].numel()) + self._built = True + + def launch_reduce_scatters(self): + if not self._built: + self._build() + if not self._distributed: + return + self._rs_futures = [] + for m in self._bank_meta: + p = m["p"] + if p.grad is None: + self._rs_futures.append(None) + continue + pg = m["padded_grad"] + pg[: m["B"]].copy_(p.grad) + fut = dist.reduce_scatter_tensor( + m["shard"], pg, op=dist.ReduceOp.AVG, async_op=True + ) + self._rs_futures.append(fut) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + if not self._built: + self._build() + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + row_normalize = group.get("row_normalize", False) + prev_ag_handle = None + prev_m = None + sharded = self._distributed and hasattr(self, "_rs_futures") + for idx, m in enumerate(self._bank_meta): + p = m["p"] + if p.grad is None: + continue + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m["p"] + upd = prev_m["full_update"][: prev_m["B"]] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd, alpha=-lr * prev_m["scale"]) + if sharded and self._rs_futures[idx] is not None: + self._rs_futures[idx].wait() + g = m["shard"] + buf = m["shard_mom"] + else: + g = p.grad.bfloat16() + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + update = g.add(buf, alpha=momentum) + else: + update = buf + if row_normalize: + rn = update.float().norm(dim=-1, keepdim=True).clamp_min(1e-07) + update = update / rn.to(update.dtype) + update = zeropower_via_newtonschulz5(update, steps=backend_steps) + if sharded: + prev_ag_handle = dist.all_gather_into_tensor( + m["full_update"], update, async_op=True + ) + prev_m = m + else: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + p.add_(update, alpha=-lr * m["scale"]) + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m["p"] + upd = prev_m["full_update"][: prev_m["B"]] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd, alpha=-lr * prev_m["scale"]) + if hasattr(self, "_rs_futures"): + del self._rs_futures + return loss + + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,skip_gates,parallel_post_lambdas,parallel_resid_lambdas,attn_gate_proj,attn_gate_w,smear_gate,smear_lambda", + ).split(",") + if pattern +) + + +PACKED_REPLICATED_GRAD_MAX_NUMEL = 1 << 15 + + +class Optimizers: + def __init__(self, h, base_model): + matrix_params = [ + base_model.qo_bank, + base_model.kv_bank, + base_model.mlp_up_bank, + base_model.mlp_down_bank, + ] + block_named_params = list(base_model.blocks.named_parameters()) + scalar_params = [ + p + for (name, p) in block_named_params + if p.ndim < 2 + or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if base_model.skip_gates is not None and base_model.skip_gates.numel() > 0: + scalar_params.append(base_model.skip_gates) + if base_model.parallel_post_lambdas is not None: + scalar_params.append(base_model.parallel_post_lambdas) + if base_model.parallel_resid_lambdas is not None: + scalar_params.append(base_model.parallel_resid_lambdas) + # SmearGate params live on GPT root (not in .blocks), so add them by hand. + # Both are tiny (gate_window scalars + 1 lambda). Optimized via scalar Adam. + if getattr(base_model, "smear_gate_enabled", False): + scalar_params.append(base_model.smear_gate.weight) + scalar_params.append(base_model.smear_lambda) + token_lr = h.tied_embed_lr if h.tie_embeddings else h.embed_lr + tok_params = [ + {"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr} + ] + self.optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=h.embed_wd, + fused=True, + ) + self.optimizer_muon = Muon( + matrix_params, + lr=h.matrix_lr, + momentum=h.muon_momentum, + backend_steps=h.muon_backend_steps, + weight_decay=h.muon_wd, + row_normalize=h.muon_row_normalize, + ) + for group in self.optimizer_muon.param_groups: + group["base_lr"] = h.matrix_lr + self.optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": h.scalar_lr, "base_lr": h.scalar_lr}], + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=h.adam_wd, + fused=True, + ) + jepa_params = ( + list(base_model.jepa_predictor.parameters()) + if getattr(base_model, "jepa_predictor", None) is not None + else [] + ) + if getattr(base_model, "jepa_bwd_predictor", None) is not None: + jepa_params.extend(base_model.jepa_bwd_predictor.parameters()) + if jepa_params: + self.optimizer_jepa = torch.optim.AdamW( + [ + { + "params": jepa_params, + "lr": h.jepa_lite_lr, + "base_lr": h.jepa_lite_lr, + } + ], + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=0.0, + fused=True, + ) + else: + self.optimizer_jepa = None + self.optimizers = [ + self.optimizer_tok, + self.optimizer_muon, + self.optimizer_scalar, + ] + if self.optimizer_jepa is not None: + self.optimizers.append(self.optimizer_jepa) + self.replicated_params = list(tok_params[0]["params"]) + self.replicated_params.extend(scalar_params) + self.replicated_params.extend(jepa_params) + self.replicated_large_params = [] + self.replicated_packed_params = [] + for p in self.replicated_params: + if p.numel() <= PACKED_REPLICATED_GRAD_MAX_NUMEL: + self.replicated_packed_params.append(p) + else: + self.replicated_large_params.append(p) + self._aux_stream = torch.cuda.Stream() + + def __iter__(self): + return iter(self.optimizers) + + def zero_grad_all(self): + for opt in self.optimizers: + opt.zero_grad(set_to_none=True) + + def _all_reduce_packed_grads(self): + grads_by_key = collections.defaultdict(list) + for p in self.replicated_packed_params: + if p.grad is not None: + grads_by_key[(p.grad.device, p.grad.dtype)].append(p.grad) + for grads in grads_by_key.values(): + flat = torch.empty( + sum(g.numel() for g in grads), + device=grads[0].device, + dtype=grads[0].dtype, + ) + offset = 0 + for g in grads: + n = g.numel() + flat[offset : offset + n].copy_(g.contiguous().view(-1)) + offset += n + dist.all_reduce(flat, op=dist.ReduceOp.AVG) + offset = 0 + for g in grads: + n = g.numel() + g.copy_(flat[offset : offset + n].view_as(g)) + offset += n + + def step(self, distributed=False): + self.optimizer_muon.launch_reduce_scatters() + if distributed: + reduce_handles = [ + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG, async_op=True) + for p in self.replicated_large_params + if p.grad is not None + ] + self._all_reduce_packed_grads() + for handle in reduce_handles: + handle.wait() + self._aux_stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self._aux_stream): + self.optimizer_tok.step() + self.optimizer_scalar.step() + if self.optimizer_jepa is not None: + self.optimizer_jepa.step() + self.optimizer_muon.step() + torch.cuda.current_stream().wait_stream(self._aux_stream) + self.zero_grad_all() + + +def restore_fp32_params(model): + for module in model.modules(): + if isinstance(module, CastedLinear): + module.float() + for name, param in model.named_parameters(): + if ( + param.ndim < 2 + or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ) and param.dtype != torch.float32: + param.data = param.data.float() + if hasattr(model, "qo_bank") and model.qo_bank is not None: + model.qo_bank.data = model.qo_bank.data.float() + model.kv_bank.data = model.kv_bank.data.float() + model.mlp_up_bank.data = model.mlp_up_bank.data.float() + model.mlp_down_bank.data = model.mlp_down_bank.data.float() + + +def collect_hessians(model, train_loader, h, device, n_calibration_batches=64): + hessians = {} + act_sumsq = {} + act_counts = {} + hooks = [] + for i, block in enumerate(model.blocks): + block.attn._calib = True + block.mlp._calib = True + block.mlp.use_fused = False + + def make_attn_hook(layer_idx): + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + x_sq = x.square().sum(dim=0) + x_count = x.shape[0] + for suffix in ["c_q", "c_k", "c_v"]: + name = f"blocks.{layer_idx}.attn.{suffix}.weight" + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + if name not in act_sumsq: + act_sumsq[name] = torch.zeros( + x.shape[1], dtype=torch.float32, device=device + ) + act_counts[name] = 0 + act_sumsq[name] += x_sq + act_counts[name] += x_count + y = module._last_proj_input + if y is not None: + y = y.float() + if y.ndim == 3: + y = y.reshape(-1, y.shape[-1]) + name = f"blocks.{layer_idx}.attn.proj.weight" + if name not in hessians: + hessians[name] = torch.zeros( + y.shape[1], y.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(y.T, y) + if name not in act_sumsq: + act_sumsq[name] = torch.zeros( + y.shape[1], dtype=torch.float32, device=device + ) + act_counts[name] = 0 + act_sumsq[name] += y.square().sum(dim=0) + act_counts[name] += y.shape[0] + return hook_fn + + def make_mlp_hook(layer_idx): + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + name = f"blocks.{layer_idx}.mlp.fc.weight" + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + if name not in act_sumsq: + act_sumsq[name] = torch.zeros( + x.shape[1], dtype=torch.float32, device=device + ) + act_counts[name] = 0 + act_sumsq[name] += x.square().sum(dim=0) + act_counts[name] += x.shape[0] + h_act = module._last_down_input + if h_act is not None: + h_act = h_act.float() + if h_act.ndim == 3: + h_act = h_act.reshape(-1, h_act.shape[-1]) + name = f"blocks.{layer_idx}.mlp.proj.weight" + if name not in hessians: + hessians[name] = torch.zeros( + h_act.shape[1], h_act.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(h_act.T, h_act) + if name not in act_sumsq: + act_sumsq[name] = torch.zeros( + h_act.shape[1], dtype=torch.float32, device=device + ) + act_counts[name] = 0 + act_sumsq[name] += h_act.square().sum(dim=0) + act_counts[name] += h_act.shape[0] + return hook_fn + + for i, block in enumerate(model.blocks): + hooks.append(block.attn.register_forward_hook(make_attn_hook(i))) + hooks.append(block.mlp.register_forward_hook(make_mlp_hook(i))) + + # Hessian hooks for embedding factorization projection layers + def make_linear_input_hook(weight_name): + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + if weight_name not in hessians: + hessians[weight_name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[weight_name].addmm_(x.T, x) + return hook_fn + + if model.tie_embeddings: + hook_module = model.final_norm + + def make_output_hook(name): + def hook_fn(module, inp, out): + x = out.detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + if name not in act_sumsq: + act_sumsq[name] = torch.zeros( + x.shape[1], dtype=torch.float32, device=device + ) + act_counts[name] = 0 + act_sumsq[name] += x.square().sum(dim=0) + act_counts[name] += x.shape[0] + return hook_fn + + hooks.append( + hook_module.register_forward_hook(make_output_hook("tok_emb.weight")) + ) + model.eval() + with torch.no_grad(): + for _ in range(n_calibration_batches): + x, _ = train_loader.next_batch(h.train_batch_tokens, h.grad_accum_steps) + model.forward_logits(x) + for hook in hooks: + hook.remove() + for i, block in enumerate(model.blocks): + block.attn._calib = False + block.mlp._calib = False + block.mlp.use_fused = True + for name in hessians: + hessians[name] = hessians[name].cpu() / n_calibration_batches + act_stats = {} + for name, sumsq in act_sumsq.items(): + count = max(act_counts.get(name, 0), 1) + act_stats[name] = (sumsq / count).sqrt().cpu() + return hessians, act_stats + + +def gptq_quantize_weight( + w, + H, + clip_sigmas=3.0, + clip_range=63, + block_size=128, + protect_groups=None, + group_size=None, + protect_clip_range=None, +): + W_orig = w.float().clone() + rows, cols = W_orig.shape + H = H.float().clone() + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + damp = 0.01 * H.diag().mean() + H.diagonal().add_(damp) + perm = torch.argsort(H.diag(), descending=True) + invperm = torch.argsort(perm) + W_perm = W_orig[:, perm].clone() + W_perm[:, dead[perm]] = 0 + H = H[perm][:, perm] + H_flip = torch.flip(H, dims=(0, 1)) + L_flip = torch.linalg.cholesky(H_flip) + U = torch.flip(L_flip, dims=(0, 1)) + eye = torch.eye(H.shape[0], device=H.device, dtype=H.dtype) + Hinv = torch.linalg.solve_triangular(U, eye, upper=True) + row_std = W_orig.std(dim=1) + s = (clip_sigmas * row_std / clip_range).clamp_min(1e-10).to(torch.float16) + sf = s.float() + protect_meta = None + protect_mask_perm = None + s_hi = None + sf_hi = None + if ( + protect_groups + and group_size is not None + and protect_clip_range is not None + and protect_clip_range > clip_range + ): + protect_mask = torch.zeros(cols, dtype=torch.bool) + starts = [] + for (start, end) in protect_groups: + if start < 0 or end > cols or end <= start: + continue + protect_mask[start:end] = True + starts.append(start) + if starts: + protect_mask_perm = protect_mask[perm] + s_hi = (clip_sigmas * row_std / protect_clip_range).clamp_min(1e-10).to( + torch.float16 + ) + sf_hi = s_hi.float() + protect_meta = { + "starts": torch.tensor(starts, dtype=torch.int16), + "size": int(group_size), + "s_hi": s_hi, + } + Q = torch.zeros(rows, cols, dtype=torch.int8) + W_work = W_perm.clone() + for i1 in range(0, cols, block_size): + i2 = min(i1 + block_size, cols) + W_block = W_work[:, i1:i2].clone() + Hinv_block = Hinv[i1:i2, i1:i2] + Err = torch.zeros(rows, i2 - i1) + for j in range(i2 - i1): + w_col = W_block[:, j] + d = Hinv_block[j, j] + if protect_mask_perm is not None and bool(protect_mask_perm[i1 + j]): + q_col = torch.clamp( + torch.round(w_col / sf_hi), + -protect_clip_range, + protect_clip_range, + ) + w_recon = q_col.float() * sf_hi + else: + q_col = torch.clamp(torch.round(w_col / sf), -clip_range, clip_range) + w_recon = q_col.float() * sf + Q[:, i1 + j] = q_col.to(torch.int8) + err = (w_col - w_recon) / d + Err[:, j] = err + W_block[:, j:] -= err.unsqueeze(1) * Hinv_block[j, j:].unsqueeze(0) + if i2 < cols: + W_work[:, i2:] -= Err @ Hinv[i1:i2, i2:] + return Q[:, invperm], s, protect_meta + + +def _quantize_gate_int8_row(w): + # Symmetric int8-per-row quantization for small gate tensors. w shape + # (R, C) -> (R,) scales in fp16, int8 values in [-127, 127]. Single scale + # per row keeps accuracy high while halving storage vs fp16. + W = w.float().contiguous() + row_max = W.abs().amax(dim=1).clamp_min(1e-10) + s = (row_max / 127.0).to(torch.float16) + sf = s.float().view(-1, 1) + q = torch.clamp(torch.round(W / sf), -127, 127).to(torch.int8) + return q, s + + +def _lqer_pack(A, B, bits): + rng = 2 ** (bits - 1) - 1 + sA = (A.abs().amax(dim=1).clamp_min(1e-10) / rng).to(torch.float16) + sB = (B.abs().amax(dim=1).clamp_min(1e-10) / rng).to(torch.float16) + qA = torch.clamp(torch.round(A / sA.float().view(-1, 1)), -rng, rng).to(torch.int8) + qB = torch.clamp(torch.round(B / sB.float().view(-1, 1)), -rng, rng).to(torch.int8) + return qA, sA, qB, sB + + +def _lqer_pack_asym(A, B, g=64): + # A: INT2 per-matrix scalar (signed [-2,1], scale = |A|max/1.5). + sA = (A.abs().amax().clamp_min(1e-10) / 1.5).to(torch.float16) + qA = torch.clamp(torch.round(A / sA.float()), -2, 1).to(torch.int8) + # B: INT4 groupwise g over flattened B (signed [-8,7], per-group scale). + Bf = B.reshape(-1, g) + Bmax = Bf.abs().amax(dim=-1, keepdim=True).clamp_min(1e-10) + sB = (Bmax / 7.5).to(torch.float16).reshape(-1) + qB = torch.clamp(torch.round(Bf / sB.float().reshape(-1, 1)), -8, 7).to( + torch.int8 + ).reshape(B.shape) + return qA, sA, qB, sB + + +def _lqer_fit_quantized(E, h): + U, S, Vh = torch.linalg.svd(E, full_matrices=False) + r = min(h.lqer_rank, S.numel()) + if r <= 0: + return None + A = (U[:, :r] * S[:r]).contiguous() + B = Vh[:r, :].contiguous() + asym_on = bool(getattr(h, "lqer_asym_enabled", False)) + asym_g = int(getattr(h, "lqer_asym_group", 64)) + if asym_on and B.numel() % asym_g == 0: + qA, sA, qB, sB = _lqer_pack_asym(A, B, asym_g) + A_hat = qA.float() * float(sA) + g_sz = qB.numel() // sB.numel() + B_hat = (qB.reshape(-1, g_sz).float() * sB.float().view(-1, 1)).reshape( + qB.shape + ) + return { + "kind": "asym", + "qA": qA, + "sA": sA, + "qB": qB, + "sB": sB, + "delta": A_hat @ B_hat, + } + qA, sA, qB, sB = _lqer_pack(A, B, h.lqer_factor_bits) + A_hat = qA.float() * sA.float().view(-1, 1) + B_hat = qB.float() * sB.float().view(-1, 1) + return { + "kind": "sym", + "qA": qA, + "sA": sA, + "qB": qB, + "sB": sB, + "delta": A_hat @ B_hat, + } + + +def _awq_lite_group_candidates(w, act_rms, group_size): + cols = w.shape[1] + n_groups = cols // group_size + if n_groups <= 0: + return [] + weight_score = w.float().abs().mean(dim=0) + saliency = act_rms.float() * weight_score + cands = [] + for gi in range(n_groups): + start = gi * group_size + end = start + group_size + score = float(saliency[start:end].sum()) + cands.append((score, start, end)) + return cands + + +def gptq_mixed_quantize(state_dict, hessians, act_stats, h): + result = {} + meta = {} + quant_gate = bool(getattr(h, "gated_attn_quant_gate", False)) + lqer_on = bool(getattr(h, "lqer_enabled", False)) + awq_on = bool(getattr(h, "awq_lite_enabled", False)) + lqer_cands = {} + awq_selected = collections.defaultdict(list) + if awq_on: + awq_cands = [] + for (name, tensor) in state_dict.items(): + t = tensor.detach().cpu().contiguous() + if t.is_floating_point() and t.numel() > 65536 and name in act_stats: + bits = h.embed_bits if "tok_emb" in name else h.matrix_bits + if bits < h.awq_lite_bits: + for score, start, end in _awq_lite_group_candidates( + t, act_stats[name], h.awq_lite_group_size + ): + awq_cands.append((score, name, start, end)) + awq_cands.sort(key=lambda x: -x[0]) + for (_score, name, start, end) in awq_cands[: h.awq_lite_group_top_k]: + awq_selected[name].append((start, end)) + for (name, tensor) in state_dict.items(): + t = tensor.detach().cpu().contiguous() + # Dedicated int8-per-row path for attn_gate_w (bypasses both GPTQ and + # fp16 passthrough). Applied BEFORE the numel<=65536 passthrough check + # so the gate tensor is routed here instead of to fp16. + if ( + quant_gate + and t.is_floating_point() + and t.ndim == 2 + and name.endswith(".attn_gate_w") + # Dense GatedAttn: (num_heads, dim) = (8, 512) = 4096. + # Sparse gate: (num_heads, gate_window) = (8, 12) = 96. + # Both need int8-per-row routing; the 1024 lower bound in stock + # PR-1736 presumed dense-only. Widen to catch both. + and 32 <= t.numel() <= 8192 + ): + gq, gs = _quantize_gate_int8_row(t) + result[name + ".gq"] = gq + result[name + ".gs"] = gs + meta[name] = "gate_int8_row" + continue + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough (float16)" + continue + if "tok_emb" in name: + cs = h.embed_clip_sigmas + elif ".mlp." in name: + cs = h.mlp_clip_sigmas + elif ".attn." in name: + cs = h.attn_clip_sigmas + else: + cs = h.matrix_clip_sigmas + bits = h.embed_bits if "tok_emb" in name else h.matrix_bits + clip_range = 2 ** (bits - 1) - 1 + q, s, protect_meta = gptq_quantize_weight( + t, + hessians[name], + clip_sigmas=cs, + clip_range=clip_range, + protect_groups=awq_selected.get(name), + group_size=h.awq_lite_group_size if name in awq_selected else None, + protect_clip_range=(2 ** (h.awq_lite_bits - 1) - 1) + if name in awq_selected + else None, + ) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = f"gptq (int{bits})" + W_q = q.float() * s.float().view(-1, 1) + if protect_meta is not None: + result[name + ".awqg_start"] = protect_meta["starts"] + result[name + ".awqg_s_hi"] = protect_meta["s_hi"] + result[name + ".awqg_size"] = torch.tensor( + protect_meta["size"], dtype=torch.int16 + ) + meta[name] = meta[name] + f"+awqgrpint{h.awq_lite_bits}" + gsz = protect_meta["size"] + for start in protect_meta["starts"].tolist(): + W_q[:, start : start + gsz] = ( + q[:, start : start + gsz].float() + * protect_meta["s_hi"].float().view(-1, 1) + ) + if lqer_on: + # LQER is fit on top of the fully realized GPTQ base, which already + # includes any higher-precision AWQ-protected groups. + scope = str(getattr(h, "lqer_scope", "all")).lower() + scope_ok = ( + scope == "all" + or (scope == "mlp" and ".mlp." in name) + or (scope == "attn" and ".attn." in name) + or (scope == "embed" and "tok_emb" in name) + ) + if scope_ok: + E = t.float() - W_q + err_norm = float(E.norm()) + if err_norm > 0: + lqer_cands[name] = (E, err_norm) + if lqer_on and lqer_cands: + if bool(getattr(h, "lqer_gain_select", False)): + scored = [] + for (name, (E, base_err)) in lqer_cands.items(): + fit = _lqer_fit_quantized(E, h) + if fit is None: + continue + new_err = float((E - fit["delta"]).norm()) + gain = base_err - new_err + if gain > 0: + scored.append((gain, name, fit)) + scored.sort(key=lambda x: -x[0]) + for (_gain, name, fit) in scored[: h.lqer_top_k]: + if fit["kind"] == "asym": + result[name + ".lqA_a"] = fit["qA"] + result[name + ".lqAs_a"] = fit["sA"] + result[name + ".lqB_a"] = fit["qB"] + result[name + ".lqBs_a"] = fit["sB"] + meta[name] = meta[name] + "+lqer_asym" + else: + result[name + ".lqA"] = fit["qA"] + result[name + ".lqAs"] = fit["sA"] + result[name + ".lqB"] = fit["qB"] + result[name + ".lqBs"] = fit["sB"] + meta[name] = meta[name] + "+lqer" + else: + top = sorted(lqer_cands.items(), key=lambda kv: -kv[1][1])[: h.lqer_top_k] + asym_on = bool(getattr(h, "lqer_asym_enabled", False)) + asym_g = int(getattr(h, "lqer_asym_group", 64)) + for (name, (E, _)) in top: + U, S, Vh = torch.linalg.svd(E, full_matrices=False) + r = min(h.lqer_rank, S.numel()) + A = (U[:, :r] * S[:r]).contiguous() + B = Vh[:r, :].contiguous() + if asym_on and B.numel() % asym_g == 0: + qA, sA, qB, sB = _lqer_pack_asym(A, B, asym_g) + result[name + ".lqA_a"] = qA + result[name + ".lqAs_a"] = sA + result[name + ".lqB_a"] = qB + result[name + ".lqBs_a"] = sB + meta[name] = meta[name] + "+lqer_asym" + else: + qA, sA, qB, sB = _lqer_pack(A, B, h.lqer_factor_bits) + result[name + ".lqA"] = qA + result[name + ".lqAs"] = sA + result[name + ".lqB"] = qB + result[name + ".lqBs"] = sB + meta[name] = meta[name] + "+lqer" + categories = collections.defaultdict(set) + for (name, cat) in meta.items(): + short = re.sub("\\.\\d+$", "", re.sub("blocks\\.\\d+", "blocks", name)) + categories[cat].add(short) + log("Quantized weights:") + for cat in sorted(categories): + log(f" {cat}: {', '.join(sorted(categories[cat]))}") + return result, meta + +def dequantize_mixed(result, meta, template_sd): + out = {} + for (name, orig) in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if "passthrough" in info: + t = result[name] + if t.dtype == torch.float16 and orig_dtype in ( + torch.float32, + torch.bfloat16, + ): + t = t.to(orig_dtype) + out[name] = t + continue + if info == "gate_int8_row": + gq = result[name + ".gq"] + gs = result[name + ".gs"] + out[name] = (gq.float() * gs.float().view(-1, 1)).to(orig_dtype) + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + W = q.float() * s.float().view(q.shape[0], *[1] * (q.ndim - 1)) + else: + W = q.float() * float(s.item()) + if "awqgrpint" in info: + starts = result[name + ".awqg_start"].tolist() + s_hi = result[name + ".awqg_s_hi"].float() + gsz = int(result[name + ".awqg_size"].item()) + for start in starts: + W[:, start : start + gsz] = ( + q[:, start : start + gsz].float() * s_hi.view(-1, 1) + ) + if "lqer_asym" in info: + qA_t = result[name + ".lqA_a"] + sA_t = result[name + ".lqAs_a"] + qB_t = result[name + ".lqB_a"] + sB_t = result[name + ".lqBs_a"] + qA = qA_t.float() * float(sA_t) + g_sz = qB_t.numel() // sB_t.numel() + qB = (qB_t.reshape(-1, g_sz).float() * sB_t.float().view(-1, 1)).reshape( + qB_t.shape + ) + W = W + qA @ qB + elif "lqer" in info: + qA = result[name + ".lqA"].float() * result[name + ".lqAs"].float().view(-1, 1) + qB = result[name + ".lqB"].float() * result[name + ".lqBs"].float().view(-1, 1) + W = W + qA @ qB + out[name] = W.to(orig_dtype) + return out + + +_BSHF_MAGIC = b"BSHF" + + +# ── Per-group lrzip compression (ported from PR#1586 via PR#1667/1729) ──────── + +_GROUP_ORDER = [ + "_tok_emb.weight.q", + "attn.c_k.weight.q", "attn.c_q.weight.q", + "attn.c_v.weight.q", "attn.proj.weight.q", + "mlp.fc.weight.q", "mlp.proj.weight.q", +] +_SIMSORT_KEYS = {"_tok_emb.weight.q", "attn.c_q.weight.q", "mlp.fc.weight.q"} +_PACK_MAGIC = b"PGRP" + + +def _similarity_sort_l1(matrix): + import numpy as _np + n = matrix.shape[0] + used = _np.zeros(n, dtype=bool) + order = [0] + used[0] = True + cur = matrix[0].astype(_np.float32) + for _ in range(n - 1): + dists = _np.sum(_np.abs(matrix[~used].astype(_np.float32) - cur), axis=1) + unused = _np.where(~used)[0] + best = unused[_np.argmin(dists)] + order.append(best) + used[best] = True + cur = matrix[best].astype(_np.float32) + return _np.array(order, dtype=_np.uint16) + + +def _lrzip_compress(data, tmpdir, label): + inp = os.path.join(tmpdir, f"{label}.bin") + out = f"{inp}.lrz" + with open(inp, "wb") as f: + f.write(data) + subprocess.run(["lrzip", "-z", "-L", "9", "-o", out, inp], capture_output=True, check=True) + with open(out, "rb") as f: + result = f.read() + os.remove(inp); os.remove(out) + return result + + +def _lrzip_decompress(data, tmpdir, label): + inp = os.path.join(tmpdir, f"{label}.lrz") + out = os.path.join(tmpdir, f"{label}.bin") + with open(inp, "wb") as f: + f.write(data) + subprocess.run(["lrzip", "-d", "-f", "-o", out, inp], capture_output=True, check=True) + with open(out, "rb") as f: + result = f.read() + os.remove(inp); os.remove(out) + return result + + +def _pack_streams(streams): + import struct + n = len(streams) + hdr = _PACK_MAGIC + struct.pack("", p) + if m: + return bytes([int(m.group(1), 16)]) + return (" " + p[1:]).encode() if p.startswith("▁") else p.encode() + + +def _ppm_mixture_bpb(tgt_np, lp_np, sp, O=4, H=0.9, L_=0.05, T=0.9, token_byte_lens_np=None): + V = sp.vocab_size() + piece_bytes = [None] * V + piece_lens = np.zeros(V, dtype=np.int32) + for i in range(V): + b = _ppm_piece_bytes(sp, i) + piece_bytes[i] = b + piece_lens[i] = len(b) + if token_byte_lens_np is None: + per_tok_len = piece_lens[tgt_np] + bs = b''.join(piece_bytes[int(t)] for t in tgt_np) + kept_lp = lp_np + else: + chunks = [] + kept_lp_parts = [] + lens_parts = [] + for t, lp, side_len in zip(tgt_np, lp_np, token_byte_lens_np): + side_len = int(side_len) + if side_len <= 0: + continue + b = piece_bytes[int(t)] + if not b: + continue + if len(b) > side_len: + b = b[:side_len] + elif len(b) < side_len: + b = b + b[-1:] * (side_len - len(b)) + chunks.append(b) + kept_lp_parts.append(float(lp)) + lens_parts.append(side_len) + if not chunks: + return float("inf") + bs = b''.join(chunks) + per_tok_len = np.asarray(lens_parts, dtype=np.int32) + kept_lp = np.asarray(kept_lp_parts, dtype=np.float64) + N = len(bs) + rep_lp = np.repeat(kept_lp.astype(np.float64), per_tok_len) + rep_len = np.repeat(per_tok_len.astype(np.float64), per_tok_len) + nlp = np.where(rep_len > 0, rep_lp / rep_len, 0.0) + tabs = [dict() for _ in range(O + 1)] + plp = np.empty(N, dtype=np.float64) + cf = np.empty(N, dtype=np.float64) + LN256 = math.log(1 / 256) + log_ = math.log + h_ctx = b'' + for i in range(N): + x = bs[i] + if i == 0: + plp[i] = LN256 + cf[i] = 1 / 256 + else: + esc = 1.0 + pf = 0.0 + cf_mx = 0 + cf_tot = 256 + cf_seen = False + lim = O if i > O else i + for o in range(lim, -1, -1): + k = h_ctx[-o:] if o else b'' + e = tabs[o].get(k) + if e is None: + continue + if not cf_seen: + cf_mx = e[1] + cf_tot = e[0] + cf_seen = True + tot = e[0] + d = e[2] + c = d.get(x, 0) + if c > 0: + pf = esc * (2 * c - 1) / (2 * tot) + break + esc *= len(d) / (2 * tot) + else: + pf = esc / 256 + if pf < 1e-20: + pf = 1e-20 + plp[i] = log_(pf) + cf[i] = (cf_mx / cf_tot) if cf_seen else 1 / 256 + for o in range(O + 1): + k = h_ctx[-o:] if o else b'' + e = tabs[o].get(k) + if e is None: + tabs[o][k] = [1, 1, {x: 1}] + else: + e[0] += 1 + d = e[2] + cnt = d.get(x, 0) + 1 + d[x] = cnt + if cnt > e[1]: + e[1] = cnt + h_ctx = (h_ctx + bytes([x]))[-O:] + nn_prob = np.exp(nlp) + ppm_prob = np.exp(plp) + + def _mix_bpb(Hv, Lv, Tv): + lam_v = np.where(cf > Tv, Lv, Hv) + pm_v = lam_v * nn_prob + (1 - lam_v) * ppm_prob + return float(-np.log2(np.maximum(pm_v, 1e-300)).sum() / N) + + default_bpb = _mix_bpb(H, L_, T) + if os.environ.get("PPM_SWEEP_GRID", "0") == "1": + hs = [float(x) for x in os.environ.get("PPM_SWEEP_HS", str(H)).split(",") if x.strip()] + ls = [float(x) for x in os.environ.get("PPM_SWEEP_LS", str(L_)).split(",") if x.strip()] + ts = [float(x) for x in os.environ.get("PPM_SWEEP_TS", str(T)).split(",") if x.strip()] + combo_count = len(hs) * len(ls) * len(ts) + max_combos = int(os.environ.get("PPM_SWEEP_MAX_COMBOS", "256")) + if combo_count > max_combos and os.environ.get("PPM_SWEEP_ALLOW_SLOW", "0") != "1": + log( + f"ppm_sweep skipped: combos={combo_count} max={max_combos}; " + "dump inputs and replay offline, or set PPM_SWEEP_ALLOW_SLOW=1" + ) + return default_bpb + best = (default_bpb, H, L_, T) + for Hv in hs: + for Lv in ls: + for Tv in ts: + bpb = _mix_bpb(Hv, Lv, Tv) + if bpb < best[0]: + best = (bpb, Hv, Lv, Tv) + log( + f"ppm_sweep best_bpb:{best[0]:.8f} H={best[1]} L={best[2]} T={best[3]} " + f"default_bpb:{default_bpb:.8f}" + ) + if os.environ.get("PPM_SWEEP_APPLY", "0") == "1": + return best[0] + return default_bpb + + +def eval_val_ppm_sliding(h, device, val_data, model, batch_seqs=32): + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + model.eval() + seq_len = h.eval_seq_len + stride = h.eval_stride + context_size = seq_len - stride + total_tokens = val_data.val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) if ws + context_size < total_tokens] + total_windows = len(window_starts) + my_s = total_windows * h.rank // h.world_size + my_e = total_windows * (h.rank + 1) // h.world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + tga_local = [] + lpa_local = [] + bla_local = [] + fwd_fn = model.module.forward_logits if hasattr(model, 'module') else model.forward_logits + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens = [] + for i, ws in enumerate(batch_ws): + we = min(ws + seq_len, total_tokens) + wlen = we - ws + wlens.append(wlen) + chunk = val_data.val_tokens[ws:we + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type='cuda', dtype=torch.bfloat16): + logits = fwd_fn(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction='none', + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else context_size + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + if val_data.val_bytes is not None: + tb = val_data.val_bytes[ws + s + 1: ws + wlen + 1].to(device=device, dtype=torch.float64) + else: + prev = x_batch[i, s:wlen] + tb = val_data.base_bytes_lut[tgt].to(torch.float64) + tb += (val_data.has_leading_space_lut[tgt] & ~val_data.is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + tga_local.append(tgt.cpu().to(torch.int64)) + lpa_local.append((-scored_nll).cpu().to(torch.float64)) + bla_local.append(tb.cpu().to(torch.int32)) + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss, val_bpb = _loss_bpb(loss_sum, token_count, byte_count) + if h.ppm_mixer_enabled: + tga_local_cat = torch.cat(tga_local) if tga_local else torch.zeros(0, dtype=torch.int64) + lpa_local_cat = torch.cat(lpa_local) if lpa_local else torch.zeros(0, dtype=torch.float64) + bla_local_cat = torch.cat(bla_local) if bla_local else torch.zeros(0, dtype=torch.int32) + if dist.is_available() and dist.is_initialized(): + local_size = torch.tensor([tga_local_cat.numel()], dtype=torch.int64, device=device) + sizes = [torch.zeros(1, dtype=torch.int64, device=device) for _ in range(h.world_size)] + dist.all_gather(sizes, local_size) + sizes_list = [int(s.item()) for s in sizes] + max_size = max(sizes_list) if sizes_list else 0 + tga_pad = torch.zeros(max_size, dtype=torch.int64, device=device) + lpa_pad = torch.zeros(max_size, dtype=torch.float64, device=device) + bla_pad = torch.zeros(max_size, dtype=torch.int32, device=device) + tga_pad[:tga_local_cat.numel()] = tga_local_cat.to(device) + lpa_pad[:lpa_local_cat.numel()] = lpa_local_cat.to(device) + bla_pad[:bla_local_cat.numel()] = bla_local_cat.to(device) + if h.rank == 0: + gather_t = [torch.zeros(max_size, dtype=torch.int64, device=device) for _ in range(h.world_size)] + gather_l = [torch.zeros(max_size, dtype=torch.float64, device=device) for _ in range(h.world_size)] + gather_b = [torch.zeros(max_size, dtype=torch.int32, device=device) for _ in range(h.world_size)] + else: + gather_t = None + gather_l = None + gather_b = None + dist.gather(tga_pad, gather_t, dst=0) + dist.gather(lpa_pad, gather_l, dst=0) + dist.gather(bla_pad, gather_b, dst=0) + if h.rank == 0: + tga_full = torch.cat([gather_t[r][:sizes_list[r]] for r in range(h.world_size)]).cpu().numpy() + lpa_full = torch.cat([gather_l[r][:sizes_list[r]] for r in range(h.world_size)]).cpu().numpy() + bla_full = torch.cat([gather_b[r][:sizes_list[r]] for r in range(h.world_size)]).cpu().numpy() + if getattr(h, "ppm_dump_inputs", False): + dump_path = os.path.join(h.artifact_dir or ".", f"{h.run_id}.ppm_inputs.npz") + np.savez_compressed( + dump_path, + target_ids=tga_full.astype(np.int64), + logp=lpa_full.astype(np.float64), + byte_lens=bla_full.astype(np.int32), + ) + log(f"ppm_dump_inputs:{dump_path}") + t0 = time.perf_counter() + mixer_bpb = _ppm_mixture_bpb(tga_full, lpa_full, val_data.sp, O=h.ppm_order, H=h.ppm_h, L_=h.ppm_l, T=h.ppm_t, token_byte_lens_np=bla_full) + log(f'ppm_mixer val_bpb:{mixer_bpb:.8f} eval_time:{1000.0*(time.perf_counter()-t0):.0f}ms order={h.ppm_order} H={h.ppm_h} L={h.ppm_l} T={h.ppm_t} N_tokens={lpa_full.size} N_sidecar_bytes={int(bla_full.sum())}') + val_bpb = mixer_bpb + else: + tga_np = tga_local_cat.numpy() + lpa_np = lpa_local_cat.numpy() + bla_np = bla_local_cat.numpy() + if getattr(h, "ppm_dump_inputs", False): + dump_path = os.path.join(h.artifact_dir or ".", f"{h.run_id}.ppm_inputs.npz") + np.savez_compressed( + dump_path, + target_ids=tga_np.astype(np.int64), + logp=lpa_np.astype(np.float64), + byte_lens=bla_np.astype(np.int32), + ) + log(f"ppm_dump_inputs:{dump_path}") + t0 = time.perf_counter() + mixer_bpb = _ppm_mixture_bpb(tga_np, lpa_np, val_data.sp, O=h.ppm_order, H=h.ppm_h, L_=h.ppm_l, T=h.ppm_t, token_byte_lens_np=bla_np) + log(f'ppm_mixer val_bpb:{mixer_bpb:.8f} eval_time:{1000.0*(time.perf_counter()-t0):.0f}ms order={h.ppm_order} H={h.ppm_h} L={h.ppm_l} T={h.ppm_t} N_tokens={lpa_np.size} N_sidecar_bytes={int(bla_np.sum())}') + val_bpb = mixer_bpb + model.train() + return val_loss, val_bpb + + +def eval_val(h, device, val_data, model, forward_logits_fn=None): + seq_len = h.eval_seq_len + local_batch_tokens = h.val_batch_tokens // (h.world_size * h.grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + f"VAL_BATCH_SIZE must provide at least one sequence per rank; got VAL_BATCH_SIZE={h.val_batch_tokens}, WORLD_SIZE={h.world_size}, GRAD_ACCUM_STEPS={h.grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_data.val_tokens.numel() - 1) // seq_len + seq_start = total_seqs * h.rank // h.world_size + seq_end = total_seqs * (h.rank + 1) // h.world_size + + # TODO: Don't truncate this. + seq_end = seq_start + ((seq_end - seq_start) // local_batch_seqs) * local_batch_seqs + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + run_forward_logits = ( + (model.module.forward_logits if hasattr(model, "module") else model.forward_logits) + if forward_logits_fn is None + else forward_logits_fn + ) + model.eval() + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + with torch.no_grad(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_data.val_tokens[raw_start:raw_end].to( + device=device, dtype=torch.int64, non_blocking=True + ) + x = local[:-1] + y = local[1:] + bos_pos = (x == BOS_ID).nonzero(as_tuple=True)[0].tolist() + cu_seqlens, max_seqlen = _build_cu_seqlens( + bos_pos, x.numel(), x.device, h.eval_seq_len, 64 + ) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits = run_forward_logits( + x[None], cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ).detach() + per_token_loss = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y.reshape(-1), + reduction="none", + ) + val_loss_sum += per_token_loss.to(torch.float64).sum() + val_token_count += float(y.numel()) + prev_ids = x + tgt_ids = y + sidecar_slice = val_data.val_bytes[raw_start + 1 : raw_end].to( + device=device, dtype=torch.int32, non_blocking=True + ) + val_byte_count += sidecar_slice.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + model.train() + return _loss_bpb(val_loss_sum, val_token_count, val_byte_count) + + +def jepa_weight_at(h, step, frac): + if not h.jepa_lite_enabled: + return 0.0 + freq = max(1, h.jepa_lite_frequency) + if step % freq != 0 or frac < h.jepa_lite_start_frac: + return 0.0 + weight = h.jepa_lite_weight + if frac > h.jepa_lite_end_frac: + weight *= max( + 0.0, + (1.0 - frac) / max(1e-09, 1.0 - h.jepa_lite_end_frac), + ) + return weight + + +def _find_docs(all_tokens): + bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].numpy() + docs = [] + for i in range(len(bos_positions)): + start = int(bos_positions[i]) + end = ( + int(bos_positions[i + 1]) + if i + 1 < len(bos_positions) + else all_tokens.numel() + ) + if i + 1 < len(bos_positions): + end += 1 + assert end - start >= 2 + docs.append((start, end - start)) + return docs + + +def _build_ttt_global_batches(doc_entries, h, ascending=False): + batch_size = h.ttt_batch_size + global_doc_entries = sorted(doc_entries, key=lambda x: x[1][1]) + global_batches = [ + global_doc_entries[i : i + batch_size] + for i in range(0, len(global_doc_entries), batch_size) + ] + indexed = list(enumerate(global_batches)) + if not ascending: + indexed.sort(key=lambda ib: -max(dl for _, (_, dl) in ib[1])) + return indexed + + +def _init_batch_counter(path): + with open(path, "wb") as f: + f.write((0).to_bytes(4, "little")) + + +def _claim_next_batch(counter_path, queue_len): + try: + with open(counter_path, "r+b") as f: + fcntl.flock(f, fcntl.LOCK_EX) + idx = int.from_bytes(f.read(4), "little") + f.seek(0) + f.write((idx + 1).to_bytes(4, "little")) + f.flush() + except FileNotFoundError: + return queue_len + return idx + + +def _compute_chunk_window(ci, pred_len, num_chunks, chunk_size, eval_seq_len): + chunk_end = pred_len if ci == num_chunks - 1 else (ci + 1) * chunk_size + win_start = max(0, chunk_end - eval_seq_len) + win_len = chunk_end - win_start + chunk_start = ci * chunk_size + chunk_offset = chunk_start - win_start + chunk_len = chunk_end - chunk_start + return win_start, win_len, chunk_offset, chunk_len + + +def _accumulate_bpb( + ptl, + x, + y, + chunk_offsets, + chunk_lens, + pos_idx, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + loss_sum, + byte_sum, + token_count, + y_bytes=None, +): + pos = pos_idx[: x.size(1)].unsqueeze(0) + mask = ( + (chunk_lens.unsqueeze(1) > 0) + & (pos >= chunk_offsets.unsqueeze(1)) + & (pos < (chunk_offsets + chunk_lens).unsqueeze(1)) + ) + mask_f64 = mask.to(torch.float64) + if y_bytes is not None: + tok_bytes = y_bytes.to(torch.float64) + else: + tok_bytes = base_bytes_lut[y].to(torch.float64) + tok_bytes += (has_leading_space_lut[y] & ~is_boundary_token_lut[x]).to( + torch.float64 + ) + loss_sum += (ptl.to(torch.float64) * mask_f64).sum() + byte_sum += (tok_bytes * mask_f64).sum() + token_count += chunk_lens.to(torch.float64).sum() + + +def _loss_bpb_from_sums(loss_sum, token_count, byte_sum): + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_sum.item()) + return val_loss, val_bpb + + +def _add_to_counter(path, delta): + try: + with open(path, "r+b") as f: + fcntl.flock(f, fcntl.LOCK_EX) + cur = int.from_bytes(f.read(8), "little", signed=True) + cur += int(delta) + f.seek(0) + f.write(int(cur).to_bytes(8, "little", signed=True)) + f.flush() + return cur + except FileNotFoundError: + return int(delta) + + +def _init_int64_counter(path): + with open(path, "wb") as f: + f.write((0).to_bytes(8, "little", signed=True)) + + +def _select_ttt_doc_entries(docs, h): + doc_entries = list(enumerate(docs)) + if h.val_doc_fraction < 1.0: + sample_n = max(1, int(round(len(docs) * h.val_doc_fraction))) + if os.environ.get("VAL_DOC_PREFIX_ONLY", "0") == "1": + return doc_entries[:sample_n] + sampled_indices = sorted( + random.Random(h.seed).sample(range(len(docs)), sample_n) + ) + return [(i, docs[i]) for i in sampled_indices] + return doc_entries + + +def train_val_ttt_global_sgd_distributed(h, device, val_data, base_model, val_tokens, batch_seqs=None): + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + base_model.eval() + seq_len = h.eval_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = h.global_ttt_chunk_tokens + batch_seqs = h.global_ttt_batch_seqs if batch_seqs is None else batch_seqs + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + ttt_params = [p for p in base_model.parameters()] + for p in ttt_params: + p.requires_grad_(True) + optimizer = torch.optim.SGD( + ttt_params, lr=h.global_ttt_lr, momentum=h.global_ttt_momentum + ) + t_start = time.perf_counter() + for ci in range(num_chunks): + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + is_last_chunk = ci == num_chunks - 1 + if is_last_chunk or h.global_ttt_epochs <= 0: + continue + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs <= 0: + continue + warmup_chunks = max(0, min(h.global_ttt_warmup_chunks, num_chunks - 1)) + if warmup_chunks > 0 and ci < warmup_chunks: + warmup_denom = max(warmup_chunks - 1, 1) + warmup_t = ci / warmup_denom + lr_now = ( + h.global_ttt_warmup_start_lr + + (h.global_ttt_lr - h.global_ttt_warmup_start_lr) * warmup_t + ) + else: + decay_steps = max(num_chunks - 1 - warmup_chunks, 1) + decay_ci = max(ci - warmup_chunks, 0) + lr_now = h.global_ttt_lr * 0.5 * ( + 1.0 + math.cos(math.pi * decay_ci / decay_steps) + ) + for pg in optimizer.param_groups: + pg["lr"] = lr_now + my_seq_s = chunk_seqs * h.rank // h.world_size + my_seq_e = chunk_seqs * (h.rank + 1) // h.world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ in range(h.global_ttt_epochs): + for bs in range(0, my_chunk_seqs, batch_seqs): + be = min(bs + batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x_flat = local[:-1] + y_flat = local[1:] + optimizer.zero_grad(set_to_none=True) + with torch.enable_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + if h.global_ttt_respect_doc_boundaries: + bos_pos = (x_flat == BOS_ID).nonzero(as_tuple=True)[0].tolist() + cu_seqlens, max_seqlen = _build_cu_seqlens( + bos_pos, x_flat.numel(), x_flat.device, h.eval_seq_len, 64 + ) + loss = base_model( + x_flat[None], + y_flat[None], + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + else: + x = x_flat.reshape(-1, seq_len) + y = y_flat.reshape(-1, seq_len) + loss = base_model(x, y) + loss.backward() + if dist.is_available() and dist.is_initialized(): + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.SUM) + p.grad.mul_(1.0 / h.world_size) + if h.global_ttt_grad_clip > 0: + torch.nn.utils.clip_grad_norm_(ttt_params, h.global_ttt_grad_clip) + optimizer.step() + base_model.eval() + if h.rank == 0: + elapsed = time.perf_counter() - t_start + log( + f"tttg: c{ci+1}/{num_chunks} lr:{lr_now:.6f} t:{elapsed:.1f}s" + ) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + +def _compute_ngram_hints_for_val(h, val_data, log0=print): + if not getattr(h, "ngram_tilt_enabled", False): + return None + from online_ngram_tilt import build_hints_for_targets + + all_tokens = val_data.val_tokens + targets_np_all = all_tokens.cpu().numpy().astype("uint16", copy=False)[1:] + max_targets = int(os.environ.get("NGRAM_HINT_MAX_TARGETS", "0")) + target_count = targets_np_all.shape[0] + if max_targets > 0: + targets_np = targets_np_all[: min(max_targets, target_count)] + else: + targets_np = targets_np_all + t_h0 = time.perf_counter() + hints_pkg = build_hints_for_targets( + target_token_ids_np=targets_np, + tokenizer_path=h.tokenizer_path, + vocab_size=h.vocab_size, + log0=log0, + token_order=h.token_order, + token_threshold=h.token_threshold, + token_boost=h.token_boost, + within_tau=h.within_tau, + within_boost=h.within_boost, + word_order=h.word_order, + word_normalize=h.word_normalize, + word_tau=h.word_tau, + word_boost=h.word_boost, + agree_add_boost=h.agree_add_boost, + ) + hint_global = torch.from_numpy(hints_pkg["hint_ids"].astype("int64")) + gate_global = torch.from_numpy(hints_pkg["gate_mask"]) + boost_global = torch.from_numpy(hints_pkg["boost"].astype("float32")) + if hint_global.numel() < target_count: + padded_hint = torch.zeros(target_count, dtype=torch.int64) + padded_gate = torch.zeros(target_count, dtype=torch.bool) + padded_boost = torch.zeros(target_count, dtype=torch.float32) + padded_hint[: hint_global.numel()] = hint_global + padded_gate[: gate_global.numel()] = gate_global + padded_boost[: boost_global.numel()] = boost_global + hint_global, gate_global, boost_global = padded_hint, padded_gate, padded_boost + log0( + f"ngram_tilt:precompute_done elapsed={time.perf_counter()-t_h0:.2f}s " + f"total_targets={hint_global.numel()} computed_targets={targets_np.shape[0]}" + ) + return hint_global, gate_global, boost_global + + +def eval_val_ttt_phased(h, base_model, device, val_data, forward_ttt_train, precomputed_hints=None): + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + base_model.eval() + for p in base_model.parameters(): + p.requires_grad_(False) + all_tokens = val_data.val_tokens + all_tokens_idx = all_tokens.to(torch.int32) + ngram_hint_global = None + ngram_gate_global = None + ngram_boost_global = None + if precomputed_hints is not None: + ngram_hint_global, ngram_gate_global, ngram_boost_global = precomputed_hints + log( + "ngram_tilt:using_precomputed_hints " + f"total_targets={ngram_hint_global.numel()}" + ) + elif getattr(h, "ngram_tilt_enabled", False): + ngram_hint_global, ngram_gate_global, ngram_boost_global = _compute_ngram_hints_for_val( + h, val_data, log0=log + ) + docs = _find_docs(all_tokens) + doc_entries = _select_ttt_doc_entries(docs, h) + prefix_doc_limit = max(0, min(len(doc_entries), int(h.phased_ttt_prefix_docs))) + num_phases = max(1, int(h.phased_ttt_num_phases)) + phase_boundaries = [] + for pi in range(num_phases): + boundary = prefix_doc_limit * (pi + 1) // num_phases + phase_boundaries.append(boundary) + current_phase = 0 + current_phase_boundary = phase_boundaries[0] + log( + "ttt_phased:" + f" total_docs:{len(doc_entries)} prefix_docs:{prefix_doc_limit} " + f"suffix_docs:{len(doc_entries) - prefix_doc_limit}" + f" num_phases:{num_phases} boundaries:{phase_boundaries}" + ) + chunk_size, eval_seq_len = h.ttt_chunk_size, h.ttt_eval_seq_len + eval_batch_set = None + if h.ttt_eval_batches: + eval_batch_set = set(int(x) for x in h.ttt_eval_batches.split(",") if x.strip()) + use_ascending = eval_batch_set is not None + global_batches_sorted = _build_ttt_global_batches( + doc_entries, h, ascending=use_ascending + ) + queue_len = len(global_batches_sorted) + counter_path = f"/tmp/ttt_counter_{h.run_id}" + prefix_counter_path = f"/tmp/ttt_prefix_counter_{h.run_id}" + pause_flag_path = f"/tmp/ttt_pause_flag_{h.run_id}" + if h.rank == 0: + _init_batch_counter(counter_path) + _init_int64_counter(prefix_counter_path) + try: + os.remove(pause_flag_path) + except FileNotFoundError: + pass + if dist.is_available() and dist.is_initialized(): + path_list = [counter_path, prefix_counter_path, pause_flag_path] + dist.broadcast_object_list(path_list, src=0) + counter_path, prefix_counter_path, pause_flag_path = path_list + dist.barrier() + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + byte_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + t_start = time.perf_counter() + reusable_lora = BatchedTTTLoRA( + h.ttt_batch_size, base_model, h.ttt_lora_rank, + q_lora=h.ttt_q_lora, k_lora=h.ttt_k_lora, v_lora=h.ttt_v_lora, + mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + + def _build_opt(lora): + local_lr = h.ttt_lora_lr * h.ttt_local_lr_mult + if h.ttt_optimizer == "sgd": + return torch.optim.SGD( + lora.parameters(), lr=local_lr, + momentum=h.ttt_beta1, weight_decay=h.ttt_weight_decay, + ) + return torch.optim.AdamW( + lora.parameters(), lr=local_lr, + betas=(h.ttt_beta1, h.ttt_beta2), + eps=1e-10, weight_decay=h.ttt_weight_decay, fused=True, + ) + + reusable_opt = _build_opt(reusable_lora) + local_scored_docs = [] + global_ttt_done = prefix_doc_limit == 0 + try: + while True: + queue_idx = _claim_next_batch(counter_path, queue_len) + if queue_idx >= queue_len: + break + orig_batch_idx, batch_entries = global_batches_sorted[queue_idx] + batch = [doc for _, doc in batch_entries] + bsz = len(batch) + prev_loss = loss_sum.item() + prev_bytes = byte_sum.item() + prev_tokens = token_count.item() + if bsz == reusable_lora.bsz: + reusable_lora.reset() + for s in reusable_opt.state.values(): + for k, v in s.items(): + if isinstance(v, torch.Tensor): + v.zero_() + elif k == "step": + s[k] = 0 + cur_lora = reusable_lora + cur_opt = reusable_opt + else: + cur_lora = BatchedTTTLoRA( + bsz, base_model, h.ttt_lora_rank, + q_lora=h.ttt_q_lora, k_lora=h.ttt_k_lora, v_lora=h.ttt_v_lora, + mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + cur_opt = _build_opt(cur_lora) + pred_lens = [doc_len - 1 for _, doc_len in batch] + num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens] + max_nc = max(num_chunks) + num_chunks_t = torch.tensor(num_chunks, dtype=torch.int64, device=device) + for ci in range(max_nc): + active = [ci < nc for nc in num_chunks] + needs_train = any(ci < nc - 1 for nc in num_chunks) + tok_starts = torch.zeros(bsz, dtype=torch.int64) + tok_wls = torch.zeros(bsz, dtype=torch.int64) + chunk_offsets_cpu = torch.zeros(bsz, dtype=torch.int64) + chunk_lens_cpu = torch.zeros(bsz, dtype=torch.int64) + for b in range(bsz): + if not active[b]: + continue + doc_start, doc_len = batch[b] + win_start, win_len, chunk_offset, chunk_len = _compute_chunk_window( + ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len + ) + tok_starts[b] = doc_start + win_start + tok_wls[b] = win_len + chunk_offsets_cpu[b] = chunk_offset + chunk_lens_cpu[b] = chunk_len + _, context_size, chunk_offset, _ = _compute_chunk_window( + ci, (ci + 1) * chunk_size, ci + 1, chunk_size, eval_seq_len + ) + col_idx = torch.arange(context_size + 1) + idx = tok_starts.unsqueeze(1) + col_idx.unsqueeze(0) + idx.clamp_(max=all_tokens.numel() - 1) + gathered_gpu = all_tokens_idx[idx].to( + device=device, dtype=torch.int64, non_blocking=True + ) + valid = (col_idx[:context_size].unsqueeze(0) < tok_wls.unsqueeze(1)).to( + device, non_blocking=True + ) + chunk_offsets = chunk_offsets_cpu.to(device, non_blocking=True) + chunk_lens = chunk_lens_cpu.to(device, non_blocking=True) + x = torch.where(valid, gathered_gpu[:, :context_size], 0) + y = torch.where(valid, gathered_gpu[:, 1 : context_size + 1], 0) + ctx_pos = torch.arange(context_size, device=device, dtype=torch.int64) + hint_ids_gpu = None + gate_mask_gpu = None + boost_gpu = None + if ngram_hint_global is not None: + hint_idx_cpu = ( + tok_starts.unsqueeze(1) + col_idx[:context_size].unsqueeze(0) + ).clamp_(min=0, max=ngram_hint_global.numel() - 1) + hint_ids_gpu = ngram_hint_global[hint_idx_cpu].to( + device=device, dtype=torch.int64, non_blocking=True + ) + gate_mask_gpu = ngram_gate_global[hint_idx_cpu].to( + device=device, non_blocking=True + ) + boost_gpu = ngram_boost_global[hint_idx_cpu].to( + device=device, dtype=torch.float32, non_blocking=True + ) + hint_ids_gpu = torch.where(valid, hint_ids_gpu, torch.zeros_like(hint_ids_gpu)) + gate_mask_gpu = gate_mask_gpu & valid + log_q_hint = None + if hint_ids_gpu is not None: + with torch.no_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_tok_loss, log_q_hint = forward_ttt_train( + x, y, lora=cur_lora, hint_ids=hint_ids_gpu + ) + else: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_tok_loss = forward_ttt_train(x, y, lora=cur_lora) + # CaseOps sidecar-driven byte budget. Mirror the index pattern + # used to build y from all_tokens: y[b, j] corresponds to the + # token at global position tok_starts[b] + 1 + j (when valid). + y_bytes_arg = None + if val_data.caseops_enabled and val_data.val_bytes is not None: + y_idx = ( + tok_starts.unsqueeze(1) + + 1 + + col_idx[:context_size].unsqueeze(0) + ) + y_idx = y_idx.clamp_(max=val_data.val_bytes.numel() - 1) + y_bytes_arg = val_data.val_bytes[y_idx].to( + device=device, dtype=torch.int32, non_blocking=True + ) + # Mirror the `valid` masking used for y so out-of-range tokens + # contribute zero bytes (matches y=0 substitution above). + y_bytes_arg = torch.where( + valid, y_bytes_arg, torch.zeros_like(y_bytes_arg) + ) + if hint_ids_gpu is not None and log_q_hint is not None: + from online_ngram_tilt import apply_tilt_to_ptl_torch_fast + + scored_loss = apply_tilt_to_ptl_torch_fast( + ptl=per_tok_loss, + log_q_hint=log_q_hint, + target_ids=y, + hint_ids=hint_ids_gpu, + gate_mask=gate_mask_gpu, + boost=boost_gpu, + ) + else: + scored_loss = per_tok_loss + with torch.no_grad(): + _accumulate_bpb( + scored_loss, + x, + y, + chunk_offsets, + chunk_lens, + ctx_pos, + val_data.base_bytes_lut, + val_data.has_leading_space_lut, + val_data.is_boundary_token_lut, + loss_sum, + byte_sum, + token_count, + y_bytes=y_bytes_arg, + ) + if scored_loss is not per_tok_loss: + del scored_loss + if needs_train: + activate_chunk_mask = (num_chunks_t - 1 > ci).float() + train_x, train_y = x, y + train_chunk_offset = chunk_offset + train_window = int(getattr(h, "ttt_train_window_tokens", 0)) + if train_window > 0 and context_size > max(train_window, chunk_size): + train_window = max(train_window, chunk_size) + train_end = min(context_size, chunk_offset + chunk_size) + train_start = max(0, train_end - train_window) + train_x = x[:, train_start:train_end].contiguous() + train_y = y[:, train_start:train_end].contiguous() + train_chunk_offset = chunk_offset - train_start + for gi in range(h.ttt_grad_steps): + if hint_ids_gpu is not None or gi > 0: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + train_per_tok_loss = forward_ttt_train( + train_x, train_y, lora=cur_lora + ) + else: + train_per_tok_loss = per_tok_loss + per_doc = train_per_tok_loss[ + :, train_chunk_offset : train_chunk_offset + chunk_size + ].mean(dim=-1) + cur_opt.zero_grad(set_to_none=True) + (per_doc * activate_chunk_mask).sum().backward() + cur_opt.step() + if train_per_tok_loss is not per_tok_loss: + del train_per_tok_loss + del per_tok_loss + batch_num = orig_batch_idx + 1 + doc_lens = [dl for _, dl in batch] + should_report = batch_num in eval_batch_set if eval_batch_set is not None else True + if should_report: + cur_tokens = token_count.item() + cur_loss_val = loss_sum.item() + cur_bytes_val = byte_sum.item() + dt = cur_tokens - prev_tokens + db = cur_bytes_val - prev_bytes + if dt > 0 and db > 0: + b_loss = (cur_loss_val - prev_loss) / dt + b_bpb = b_loss / math.log(2.0) * (dt / db) + else: + b_loss = b_bpb = 0.0 + r_loss = cur_loss_val / max(cur_tokens, 1) + r_bpb = r_loss / math.log(2.0) * (cur_tokens / max(cur_bytes_val, 1)) + elapsed = time.perf_counter() - t_start + log( + f"ttp: b{batch_num}/{queue_len} bl:{b_loss:.4f} bb:{b_bpb:.4f} " + f"rl:{r_loss:.4f} rb:{r_bpb:.4f} dl:{min(doc_lens)}-{max(doc_lens)} " + f"gd:{int(global_ttt_done)}" + ) + if not global_ttt_done: + local_scored_docs.extend( + (orig_batch_idx, pos, doc_start, doc_len) + for pos, (doc_start, doc_len) in enumerate(batch) + ) + prefix_done = _add_to_counter(prefix_counter_path, len(batch_entries)) + if prefix_done >= current_phase_boundary: + try: + with open(pause_flag_path, "x"): + pass + except FileExistsError: + pass + should_pause = os.path.exists(pause_flag_path) + if should_pause: + if dist.is_available() and dist.is_initialized(): + dist.barrier() + gathered_scored_docs = [None] * h.world_size + if dist.is_available() and dist.is_initialized(): + dist.all_gather_object(gathered_scored_docs, local_scored_docs) + else: + gathered_scored_docs = [local_scored_docs] + scored_docs_for_global = [] + for rank_docs in gathered_scored_docs: + if rank_docs: + scored_docs_for_global.extend(rank_docs) + scored_docs_for_global.sort(key=lambda x: (x[0], x[1])) + scored_docs_for_global = scored_docs_for_global[:current_phase_boundary] + scored_token_chunks = [ + val_data.val_tokens[doc_start : doc_start + doc_len] + for _, _, doc_start, doc_len in scored_docs_for_global + ] + if scored_token_chunks: + global_ttt_tokens = torch.cat(scored_token_chunks) + else: + global_ttt_tokens = val_data.val_tokens[:0] + if h.rank == 0: + prefix_done = 0 + try: + with open(prefix_counter_path, "rb") as f: + prefix_done = int.from_bytes( + f.read(8), "little", signed=True + ) + except FileNotFoundError: + pass + log( + f"ttpp: phase:{current_phase + 1}/{num_phases} pd:{prefix_done} " + f"gd:{len(scored_docs_for_global)} " + f"t:{time.perf_counter() - t_start:.1f}s" + ) + train_val_ttt_global_sgd_distributed( + h, device, val_data, base_model, global_ttt_tokens + ) + for p in base_model.parameters(): + p.requires_grad_(False) + reusable_lora = BatchedTTTLoRA( + h.ttt_batch_size, base_model, h.ttt_lora_rank, + q_lora=h.ttt_q_lora, k_lora=h.ttt_k_lora, v_lora=h.ttt_v_lora, + mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + reusable_opt = _build_opt(reusable_lora) + current_phase += 1 + if current_phase >= num_phases: + global_ttt_done = True + else: + current_phase_boundary = phase_boundaries[current_phase] + if h.rank == 0: + try: + os.remove(pause_flag_path) + except FileNotFoundError: + pass + if dist.is_available() and dist.is_initialized(): + dist.barrier() + if h.rank == 0: + log(f"ttpr: phase:{current_phase}/{num_phases} t:{time.perf_counter() - t_start:.1f}s") + del cur_lora, cur_opt + finally: + pass + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.train() + return _loss_bpb_from_sums(loss_sum, token_count, byte_sum) + + +def timed_eval(label, fn, *args, **kwargs): + torch.cuda.synchronize() + t0 = time.perf_counter() + val_loss, val_bpb = fn(*args, **kwargs) + torch.cuda.synchronize() + elapsed_ms = 1e3 * (time.perf_counter() - t0) + log( + f"{label} val_loss:{val_loss:.8f} val_bpb:{val_bpb:.8f} eval_time:{elapsed_ms:.0f}ms" + ) + return val_loss, val_bpb + + +def train_model(h, device, val_data): + base_model = GPT(h).to(device).bfloat16() + restore_fp32_params(base_model) + compile_enabled = os.environ.get("DISABLE_COMPILE", "0") != "1" + if compile_enabled: + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + compiled_forward_logits = torch.compile( + base_model.forward_logits, dynamic=False, fullgraph=True + ) + else: + log("compile:disabled_by_env") + compiled_model = base_model + compiled_forward_logits = base_model.forward_logits + model = compiled_model + log(f"model_params:{sum(p.numel()for p in base_model.parameters())}") + if h.jepa_lite_enabled: + log( + "jepa_lite_config: " + f"weight={h.jepa_lite_weight} hidden={h.jepa_lite_hidden} " + f"horizon={h.jepa_lite_horizon} stride={h.jepa_lite_stride} " + f"frequency={h.jepa_lite_frequency} start_frac={h.jepa_lite_start_frac} " + f"end_frac={h.jepa_lite_end_frac} lr={h.jepa_lite_lr}" + ) + if h.bijepa_enabled: + log( + "bijepa_config: " + f"backward_weight={h.bijepa_backward_weight} " + "dual_predictor=true training_only=true" + ) + bijepax_module = None + bijepax_opt = None + forward_with_hidden = None + if h.bijepax_enabled: + bijepax_module = MultiDirectionalBiJEPAX( + d_model=h.model_dim, + fwd_hops=h.bijepax_fwd_hops, + bwd_hops=h.bijepax_bwd_hops, + cycle=h.bijepax_cycle, + head_dim=h.bijepax_head_dim, + ).to(device).bfloat16() + bijepax_opt = torch.optim.Adam( + bijepax_module.parameters(), lr=h.bijepax_lr, betas=(0.9, 0.99) + ) + forward_with_hidden = ( + torch.compile(base_model.forward_with_hidden, dynamic=False, fullgraph=True) + if compile_enabled + else base_model.forward_with_hidden + ) + log( + "bijepax_config: " + f"weight={h.bijepax_weight} start={h.bijepax_start_frac} " + f"end={h.bijepax_end_frac} fwd={h.bijepax_fwd_hops} " + f"bwd={h.bijepax_bwd_hops} cycle={h.bijepax_cycle} " + f"head_dim={h.bijepax_head_dim} stride={h.bijepax_stride} " + f"lr={h.bijepax_lr} training_only=true" + ) + optimizers = Optimizers(h, base_model) + train_loader = DocumentPackingLoader(h, device) + max_wallclock_ms = ( + 1e3 * h.max_wallclock_seconds if h.max_wallclock_seconds > 0 else None + ) + if max_wallclock_ms is not None: + max_wallclock_ms -= h.gptq_reserve_seconds * 1e3 + log( + f"gptq:reserving {h.gptq_reserve_seconds:.0f}s, effective={max_wallclock_ms:.0f}ms" + ) + + def training_frac(step, elapsed_ms): + if max_wallclock_ms is None: + return step / max(h.iterations, 1) + return elapsed_ms / max(max_wallclock_ms, 1e-09) + + def lr_mul(frac): + if h.warmdown_frac <= 0: + return 1.0 + if frac >= 1.0 - h.warmdown_frac: + return max((1.0 - frac) / h.warmdown_frac, h.min_lr) + return 1.0 + + _clip_params = [p for p in base_model.parameters() if p.requires_grad] + def step_fn(step, lr_scale, frac_now=0.0): + if getattr(base_model, "jepa_predictor", None) is not None: + base_model.jepa_weight_current.fill_(jepa_weight_at(h, step, frac_now)) + bijepax_weight = ( + bijepax_weight_at( + frac_now, + h.bijepax_start_frac, + h.bijepax_end_frac, + h.bijepax_weight, + ) + if bijepax_module is not None + else 0.0 + ) + train_loss = torch.zeros((), device=device) + for micro_step in range(h.grad_accum_steps): + x, y, cu_seqlens, _max_seqlen = train_loader.next_batch( + h.train_batch_tokens, h.grad_accum_steps + ) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + if bijepax_module is not None and bijepax_weight > 0.0: + ce_loss, hidden = forward_with_hidden( + x, y, cu_seqlens=cu_seqlens, max_seqlen=h.train_seq_len + ) + loss = ce_loss + bijepax_module( + hidden, bijepax_weight, h.bijepax_stride + ) + else: + loss = model( + x, y, cu_seqlens=cu_seqlens, max_seqlen=h.train_seq_len + ) + train_loss += loss.detach() + (loss / h.grad_accum_steps).backward() + train_loss /= h.grad_accum_steps + if step <= h.muon_momentum_warmup_steps: + + frac = ( + + min(step / h.muon_momentum_warmup_steps, 1.0) + + if h.muon_momentum_warmup_steps > 0 + + else 1.0 + + ) + + muon_momentum = ( + + 1 - frac + + ) * h.muon_momentum_warmup_start + frac * h.muon_momentum + + for group in optimizers.optimizer_muon.param_groups: + + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * lr_scale + if h.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(_clip_params, h.grad_clip_norm) + optimizers.step(distributed=h.distributed) + if bijepax_opt is not None and bijepax_weight > 0.0: + bijepax_opt.step() + bijepax_opt.zero_grad(set_to_none=True) + return train_loss + + if h.warmup_steps > 0: + initial_model_state = { + name: tensor.detach().cpu().clone() + for (name, tensor) in base_model.state_dict().items() + } + initial_optimizer_states = [ + copy.deepcopy(opt.state_dict()) for opt in optimizers + ] + model.train() + num_tokens_local = h.train_batch_tokens // h.world_size + for blk in base_model.blocks: + blk.attn.rotary(num_tokens_local, device, torch.bfloat16) + cu_bucket_size = train_loader.cu_bucket_size + warmup_cu_buckets = tuple(cu_bucket_size * i for i in range(1, 5)) + warmup_cu_iters = 3 + x, y, cu_seqlens, _ = train_loader.next_batch( + h.train_batch_tokens, h.grad_accum_steps + ) + log(f"warmup_cu_buckets:{','.join(str(b) for b in warmup_cu_buckets)} iters_each:{warmup_cu_iters}") + def _run_cu_bucket_warmup(): + for bucket_len in warmup_cu_buckets: + boundaries = list(range(0, x.size(1), max(h.train_seq_len, 1))) + if boundaries[-1] != x.size(1): + boundaries.append(x.size(1)) + cu = torch.full((bucket_len,), x.size(1), dtype=torch.int32, device=device) + cu[: len(boundaries)] = torch.tensor(boundaries, dtype=torch.int32, device=device) + for _ in range(warmup_cu_iters): + optimizers.zero_grad_all() + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + wloss = model(x, y, cu_seqlens=cu, max_seqlen=h.train_seq_len) + (wloss / h.grad_accum_steps).backward() + optimizers.zero_grad_all() + _run_cu_bucket_warmup() + if h.num_loops > 0: + base_model.looping_active = True + _run_cu_bucket_warmup() + base_model.looping_active = False + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step, 1.0) + if ( + warmup_step <= 5 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == h.warmup_steps + ): + log(f"warmup_step: {warmup_step+1}/{h.warmup_steps}") + if h.num_loops > 0: + base_model.looping_active = True + log( + f"loop_warmup:enabled encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}" + ) + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step, 1.0) + if ( + warmup_step <= 5 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == h.warmup_steps + ): + log(f"loop_warmup_step: {warmup_step+1}/{h.warmup_steps}") + base_model.looping_active = False + base_model.load_state_dict(initial_model_state, strict=True) + for (opt, state) in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + optimizers.zero_grad_all() + train_loader = DocumentPackingLoader(h, device) + _live_state = base_model.state_dict(keep_vars=True) + ema_state = { + name: t.detach().float().clone() + for (name, t) in _live_state.items() + if not name.startswith("jepa_predictor.") + and not name.startswith("jepa_bwd_predictor.") + and name != "jepa_weight_current" + } + _ema_pairs = [ + (ema_state[name], t) + for (name, t) in _live_state.items() + if name in ema_state + ] + ema_decay = h.ema_decay + training_time_ms = 0.0 + forced_stop_step = int(os.environ.get("FORCE_STOP_STEP", "0")) + stop_after_step = forced_stop_step if forced_stop_step > 0 else None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = ( + step == h.iterations + or stop_after_step is not None + and step >= stop_after_step + ) + should_validate = ( + last_step or h.val_loss_every > 0 and step % h.val_loss_every == 0 + ) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1e3 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + h, device, val_data, model, compiled_forward_logits + ) + log( + f"{step}/{h.iterations} val_loss: {val_loss:.4f} val_bpb: {val_bpb:.4f}" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < h.iterations: + log( + f"stopping_early: wallclock_cap train_time: {training_time_ms:.0f}ms step: {step}/{h.iterations}" + ) + break + elapsed_ms = training_time_ms + 1e3 * (time.perf_counter() - t0) + frac = training_frac(step, elapsed_ms) + scale = lr_mul(frac) + if ( + h.num_loops > 0 + and not base_model.looping_active + and frac >= h.enable_looping_at + ): + base_model.looping_active = True + log( + f"layer_loop:enabled step:{step} frac:{frac:.3f} encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}" + ) + train_loss = step_fn(step, scale, frac) + with torch.no_grad(): + for ema_t, t in _ema_pairs: + ema_t.mul_(ema_decay).add_(t.detach(), alpha=1.0 - ema_decay) + step += 1 + approx_training_time_ms = training_time_ms + 1e3 * (time.perf_counter() - t0) + should_log_train = h.train_log_every > 0 and ( + step <= 5 or step % h.train_log_every == 0 or stop_after_step is not None + ) + if should_log_train: + tok_per_sec = step * h.train_batch_tokens / (approx_training_time_ms / 1e3) + log( + f"{step}/{h.iterations} train_loss: {train_loss.item():.4f} train_time: {approx_training_time_ms/60000:.1f}m tok/s: {tok_per_sec:.0f}" + ) + reached_cap = ( + forced_stop_step <= 0 + and max_wallclock_ms is not None + and approx_training_time_ms >= max_wallclock_ms + ) + if h.distributed and forced_stop_step <= 0 and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log( + f"peak memory allocated: {torch.cuda.max_memory_allocated()//1024//1024} MiB reserved: {torch.cuda.max_memory_reserved()//1024//1024} MiB" + ) + log("ema:applying EMA weights") + if getattr(base_model, "jepa_predictor", None) is not None: + base_model.jepa_predictor = None + log("jepa_lite_predictor_removed_before_serialize: true") + if getattr(base_model, "jepa_bwd_predictor", None) is not None: + base_model.jepa_bwd_predictor = None + log("bijepa_backward_predictor_removed_before_serialize: true") + current_state = base_model.state_dict() + avg_state = { + name: t.to(dtype=current_state[name].dtype) for (name, t) in ema_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + return base_model, compiled_model, compiled_forward_logits + + +def train_and_eval(h, device): + global BOS_ID + random.seed(h.seed) + np.random.seed(h.seed) + torch.manual_seed(h.seed) + torch.cuda.manual_seed_all(h.seed) + if h.artifact_dir and h.is_main_process: + os.makedirs(h.artifact_dir, exist_ok=True) + val_data = ValidationData(h, device) + log( + f"train_shards: {len(list(Path(h.datasets_dir).resolve().glob('fineweb_train_*.bin')))}" + ) + log(f"val_tokens: {val_data.val_tokens.numel()-1}") + # TTT_EVAL_ONLY: skip training + GPTQ, jump straight to TTT eval on a + # pre-existing quantized artifact. Used to test TTT-only improvements + # (e.g., PR-1767's alpha/warm-start/WD) without retraining. + ttt_eval_only = os.environ.get("TTT_EVAL_ONLY", "0") == "1" + quantize_only = os.environ.get("QUANTIZE_ONLY", "0") == "1" + if ttt_eval_only: + log("TTT_EVAL_ONLY=1 — skipping training + GPTQ, loading saved artifact for TTT eval") + log(f"ttt_lora_alpha: {BatchedLinearLoRA._ALPHA}") + log(f"ttt_warm_start_a: {BatchedLinearLoRA._WARM_START_A}") + log(f"ttt_weight_decay: {h.ttt_weight_decay}") + if h.jepa_lite_enabled: + h.jepa_lite_enabled = False + log("jepa_lite_disabled_for_eval_only: true") + elif quantize_only: + log("QUANTIZE_ONLY=1 — skipping training, loading saved full-precision checkpoint") + log(f"quantize_only checkpoint: {h.model_path}") + if h.jepa_lite_enabled: + h.jepa_lite_enabled = False + log("jepa_lite_disabled_for_quantize_only: true") + if BOS_ID is None: + BOS_ID = 1 + base_model = GPT(h).to(device).bfloat16() + state = torch.load(h.model_path, map_location="cpu") + base_model.load_state_dict(state, strict=True) + del state + serialize(h, base_model, Path(__file__).read_text(encoding="utf-8")) + if h.distributed: + dist.barrier() + else: + base_model, compiled_model, compiled_forward_logits = train_model( + h, device, val_data + ) + if h.jepa_lite_enabled: + h.jepa_lite_enabled = False + log("jepa_lite_disabled_after_training: true") + torch._dynamo.reset() + timed_eval( + "diagnostic pre-quantization post-ema", + eval_val, + h, + device, + val_data, + compiled_model, + compiled_forward_logits, + ) + if os.environ.get("PREQUANT_ONLY", "0") == "1": + log("PREQUANT_ONLY=1 — skipping serialize/GPTQ/post-quant eval/TTT") + return + serialize(h, base_model, Path(__file__).read_text(encoding="utf-8")) + if h.distributed: + dist.barrier() + eval_model = deserialize(h, device) + if h.num_loops > 0: + eval_model.looping_active = True + if not ttt_eval_only: + if os.environ.get("DISABLE_COMPILE", "0") == "1": + log("eval_compile:disabled_by_env") + compiled_model = eval_model + compiled_forward_logits = eval_model.forward_logits + else: + compiled_model = torch.compile(eval_model, dynamic=False, fullgraph=True) + compiled_forward_logits = torch.compile( + eval_model.forward_logits, dynamic=False, fullgraph=True + ) + timed_eval( + "diagnostic quantized", + eval_val, + h, + device, + val_data, + compiled_model, + compiled_forward_logits, + ) + if h.ttt_enabled or not h.ppm_mixer_enabled: + del eval_model + if h.ttt_enabled: + if not ttt_eval_only: + del compiled_model + if ttt_eval_only: + del eval_model + torch._dynamo.reset() + torch.cuda.empty_cache() + ttt_model = deserialize(h, device) + if h.num_loops > 0: + ttt_model.looping_active = True + for p in ttt_model.parameters(): + p.requires_grad_(False) + + if h.rope_yarn: + _yarn_seqlen = h.train_batch_tokens // h.grad_accum_steps + for block in ttt_model.blocks: + block.attn.rotary(_yarn_seqlen, device, torch.bfloat16) + else: + for block in ttt_model.blocks: + block.attn.rotary._cos_cached = None + block.attn.rotary._sin_cached = None + block.attn.rotary._seq_len_cached = 0 + block.attn.rotary(h.ttt_eval_seq_len, device, torch.bfloat16) + + def _fwd_ttt_inner(input_ids, target_ids, lora): + return ttt_model.forward_ttt(input_ids, target_ids, lora=lora) + + def _fwd_ttt_hint_inner(input_ids, target_ids, lora, hint_ids): + return ttt_model.forward_ttt( + input_ids, target_ids, lora=lora, hint_ids=hint_ids + ) + + _fwd_ttt_compiled_inner = None + _fwd_ttt_hint_compiled_inner = None + + def _fwd_ttt(input_ids, target_ids, lora, hint_ids=None): + nonlocal _fwd_ttt_compiled_inner, _fwd_ttt_hint_compiled_inner + if os.environ.get("DISABLE_COMPILE", "0") == "1": + if hint_ids is not None: + return _fwd_ttt_hint_inner( + input_ids, target_ids, lora=lora, hint_ids=hint_ids + ) + return _fwd_ttt_inner(input_ids, target_ids, lora=lora) + if hint_ids is not None: + if _fwd_ttt_hint_compiled_inner is None: + _fwd_ttt_hint_compiled_inner = torch.compile( + _fwd_ttt_hint_inner, dynamic=True + ) + return _fwd_ttt_hint_compiled_inner( + input_ids, target_ids, lora=lora, hint_ids=hint_ids + ) + if _fwd_ttt_compiled_inner is None: + _fwd_ttt_compiled_inner = torch.compile(_fwd_ttt_inner, dynamic=True) + return _fwd_ttt_compiled_inner(input_ids, target_ids, lora=lora) + + fwd_ttt_compiled = _fwd_ttt + log(f"ttt_lora:warming up compile (random tokens, no val data)") + if BOS_ID is None: + BOS_ID = 1 + t_warmup = time.perf_counter() + warmup_bszes = [h.ttt_batch_size] + for bsz in warmup_bszes: + wl = BatchedTTTLoRA( + bsz, ttt_model, h.ttt_lora_rank, + q_lora=h.ttt_q_lora, k_lora=h.ttt_k_lora, v_lora=h.ttt_v_lora, + mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + wo = torch.optim.AdamW( + wl.parameters(), + lr=h.ttt_lora_lr * h.ttt_local_lr_mult, + betas=(h.ttt_beta1, h.ttt_beta2), + eps=1e-10, + weight_decay=h.ttt_weight_decay, + fused=True, + ) + train_warmup_lens = [h.ttt_chunk_size] + train_window = int(getattr(h, "ttt_train_window_tokens", 0)) + if train_window > h.ttt_chunk_size: + train_warmup_lens.append(train_window) + for ctx_len in train_warmup_lens: + xw = torch.randint(0, h.vocab_size, (bsz, ctx_len), device=device, dtype=torch.int64) + yw = torch.randint(0, h.vocab_size, (bsz, ctx_len), device=device, dtype=torch.int64) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = fwd_ttt_compiled(xw, yw, lora=wl) + ptl[:, : min(h.ttt_chunk_size, ctx_len)].mean(dim=-1).sum().backward() + wo.step() + wo.zero_grad(set_to_none=True) + if h.ngram_tilt_enabled: + ctx_len = h.ttt_eval_seq_len + xw = torch.randint(0, h.vocab_size, (bsz, ctx_len), device=device, dtype=torch.int64) + yw = torch.randint(0, h.vocab_size, (bsz, ctx_len), device=device, dtype=torch.int64) + hintw = torch.randint( + 0, h.vocab_size, (bsz, ctx_len), device=device, dtype=torch.int64 + ) + with torch.no_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + fwd_ttt_compiled(xw, yw, lora=wl, hint_ids=hintw) + del wl, wo + torch.cuda.empty_cache() + compile_elapsed = time.perf_counter() - t_warmup + log(f"ttt_lora:compile warmup done ({compile_elapsed:.1f}s)") + precomputed_hints = None + if h.ngram_tilt_enabled and h.ngram_hint_precompute_outside: + log("ngram_tilt:precomputing hints before TTT eval timer") + precomputed_hints = _compute_ngram_hints_for_val(h, val_data, log0=log) + log("\nbeginning TTT eval timer") + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_val_loss, ttt_val_bpb = eval_val_ttt_phased( + h, + ttt_model, + device, + val_data, + forward_ttt_train=fwd_ttt_compiled, + precomputed_hints=precomputed_hints, + ) + torch.cuda.synchronize() + ttt_eval_elapsed = time.perf_counter() - t_ttt + log( + "quantized_ttt_phased " + f"val_loss:{ttt_val_loss:.8f} val_bpb:{ttt_val_bpb:.8f} " + f"eval_time:{1e3*ttt_eval_elapsed:.0f}ms" + ) + log(f"total_eval_time:{ttt_eval_elapsed:.1f}s") + if h.ppm_mixer_enabled: + import sys as _sys + log("beginning PPM sliding eval") + _sys.stdout.flush() + torch.cuda.synchronize() + if dist.is_available() and dist.is_initialized(): + dist.barrier() + t_ppm = time.perf_counter() + try: + ppm_val_loss, ppm_val_bpb = eval_val_ppm_sliding( + h, device, val_data, ttt_model, batch_seqs=16 + ) + torch.cuda.synchronize() + ppm_elapsed = time.perf_counter() - t_ppm + log( + f"ppm_sliding val_loss:{ppm_val_loss:.8f} val_bpb:{ppm_val_bpb:.8f} " + f"eval_time:{1e3*ppm_elapsed:.0f}ms" + ) + except Exception as _e: + log(f"PPM eval error: {_e}") + import traceback as _tb + log(_tb.format_exc()) + _sys.stdout.flush() + del ttt_model + elif h.ppm_mixer_enabled: + import sys as _sys + log("beginning PPM sliding eval") + _sys.stdout.flush() + torch.cuda.synchronize() + if dist.is_available() and dist.is_initialized(): + dist.barrier() + t_ppm = time.perf_counter() + try: + ppm_val_loss, ppm_val_bpb = eval_val_ppm_sliding( + h, device, val_data, eval_model, batch_seqs=16 + ) + torch.cuda.synchronize() + ppm_elapsed = time.perf_counter() - t_ppm + log( + f"ppm_sliding val_loss:{ppm_val_loss:.8f} val_bpb:{ppm_val_bpb:.8f} " + f"eval_time:{1e3*ppm_elapsed:.0f}ms" + ) + except Exception as _e: + log(f"PPM eval error: {_e}") + import traceback as _tb + log(_tb.format_exc()) + _sys.stdout.flush() + del eval_model + + +def main(): + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError( + f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral" + ) + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.set_float32_matmul_precision("high") + from torch.backends.cuda import ( + enable_cudnn_sdp, + enable_flash_sdp, + enable_math_sdp, + enable_mem_efficient_sdp, + ) + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + torch._dynamo.config.optimize_ddp = False + torch._dynamo.config.cache_size_limit = 64 + h = Hyperparameters() + set_logging_hparams(h) + if h.is_main_process: + os.makedirs(h.artifact_dir if h.artifact_dir else "logs", exist_ok=True) + log(100 * "=", console=False) + log("Hyperparameters:", console=True) + for (k, v) in sorted(vars(type(h)).items()): + if not k.startswith("_"): + log(f" {k}: {v}", console=True) + log("=" * 100, console=False) + log("Source code:", console=False) + log("=" * 100, console=False) + with open(__file__, "r", encoding="utf-8") as _src: + log(_src.read(), console=False) + log("=" * 100, console=False) + log(f"Running Python {sys.version}", console=False) + log(f"Running PyTorch {torch.__version__}", console=False) + log("=" * 100, console=False) + train_and_eval(h, device) + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +==================================================================================================== +train_shards: 80 +val_tokens: 47851520 +compile:disabled_by_env +model_params:35945673 +bijepax_config: weight=0.01 start=0.35 end=0.8 fwd=(4,) bwd=(4,) cycle=False head_dim=32 stride=64 lr=0.001 training_only=true +gptq:reserving 0s, effective=599500ms +warmup_cu_buckets:64,128,192,256 iters_each:3 +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup:enabled encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 +0/20000 val_loss: 8.9980 val_bpb: 4.1115 +1/20000 train_loss: 8.9988 train_time: 0.0m tok/s: 4316911 +2/20000 train_loss: 12.8490 train_time: 0.0m tok/s: 3773018 +3/20000 train_loss: 10.2358 train_time: 0.0m tok/s: 3620744 +4/20000 train_loss: 8.6660 train_time: 0.0m tok/s: 3552879 +5/20000 train_loss: 7.9135 train_time: 0.0m tok/s: 3514777 +250/20000 train_loss: 2.9125 train_time: 1.0m tok/s: 3364437 +500/20000 train_loss: 2.5540 train_time: 1.9m tok/s: 3366468 +750/20000 train_loss: 2.6987 train_time: 2.9m tok/s: 3367170 +layer_loop:enabled step:899 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +1000/20000 train_loss: 2.7476 train_time: 4.1m tok/s: 3205099 +1250/20000 train_loss: 2.5457 train_time: 5.5m tok/s: 2953029 +1500/20000 train_loss: 2.5150 train_time: 7.0m tok/s: 2805985 +1750/20000 train_loss: 2.5345 train_time: 8.5m tok/s: 2709881 +2000/20000 train_loss: 2.4815 train_time: 9.9m tok/s: 2642045 +2012/20000 val_loss: 2.4311 val_bpb: 1.1108 +stopping_early: wallclock_cap train_time: 599586ms step: 2012/20000 +peak memory allocated: 46676 MiB reserved: 52266 MiB +ema:applying EMA weights +diagnostic pre-quantization post-ema val_loss:2.42899363 val_bpb:1.10988323 eval_time:9910ms +Serialized model: 135418111 bytes +Code size (uncompressed): 213292 bytes +Code size (compressed): 41999 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 67 Hessians in 4.0s +Quantized weights: + gate_int8_row: blocks.attn.attn_gate_w + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int7)+awqgrpint8+lqer_asym: tok_emb.weight + passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, parallel_post_lambdas, parallel_resid_lambdas, skip_gates, skip_weights, smear_gate.weight, smear_lambda, softcap_neg, softcap_pos +Serialize: per-group lrzip compression... +Serialize: per-group compression done in 105.0s +Serialized model quantized+pergroup: 15957540 bytes +Total submission size quantized+pergroup: 15999539 bytes +Deserialize: per-group lrzip decompression... +Deserialize: decompression done in 17.7s +eval_compile:disabled_by_env +diagnostic quantized val_loss:2.44155528 val_bpb:1.11562304 eval_time:9926ms +beginning PPM sliding eval +ppm_mixer val_bpb:0.97206308 eval_time:453715ms order=5 H=0.999 L=0.18 T=0.8 N_tokens=47851520 N_sidecar_bytes=151074499 +ppm_sliding val_loss:2.45044876 val_bpb:0.97206308 eval_time:499038ms diff --git a/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/full_seed42.txt b/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/full_seed42.txt new file mode 100644 index 0000000000..de9bd09549 --- /dev/null +++ b/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/full_seed42.txt @@ -0,0 +1,5251 @@ +==================================================================================================== +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + agree_add_boost: 0.5 + artifact_dir: /workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s42_20260501_022405 + attn_clip_sigmas: 13.0 + attn_out_gate_enabled: False + attn_out_gate_src: proj + awq_lite_bits: 8 + awq_lite_enabled: True + awq_lite_group_size: 64 + awq_lite_group_top_k: 1 + beta1: 0.9 + beta2: 0.99 + bijepa_backward_weight: 1.0 + bijepa_enabled: False + bijepax_bwd_hops: (4,) + bijepax_cycle: False + bijepax_enabled: True + bijepax_end_frac: 0.8 + bijepax_fwd_hops: (4,) + bijepax_head_dim: 32 + bijepax_lr: 0.001 + bijepax_start_frac: 0.35 + bijepax_stride: 64 + bijepax_weight: 0.01 + caseops_enabled: True + compressor: pergroup + data_dir: ./data/ + datasets_dir: /workspace/parameter-golf/data/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved + distributed: True + ema_decay: 0.9965 + embed_bits: 7 + embed_clip_sigmas: 14.0 + embed_lr: 0.6 + embed_wd: 0.085 + enable_looping_at: 0.35 + eval_seq_len: 2048 + eval_stride: 512 + fused_ce_enabled: True + gate_window: 12 + gated_attn_enabled: False + gated_attn_init_std: 0.01 + gated_attn_quant_gate: True + global_ttt_batch_seqs: 32 + global_ttt_chunk_tokens: 32768 + global_ttt_epochs: 1 + global_ttt_grad_clip: 1.0 + global_ttt_lr: 0.001 + global_ttt_momentum: 0.9 + global_ttt_respect_doc_boundaries: True + global_ttt_warmup_chunks: 0 + global_ttt_warmup_start_lr: 0.0 + gptq_calibration_batches: 16 + gptq_reserve_seconds: 0.5 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + is_main_process: True + iterations: 20000 + jepa_lite_enabled: False + jepa_lite_end_frac: 0.85 + jepa_lite_frequency: 8 + jepa_lite_hidden: 128 + jepa_lite_horizon: 4 + jepa_lite_lr: 0.01 + jepa_lite_start_frac: 0.35 + jepa_lite_stride: 16 + jepa_lite_weight: 0.03 + ln_scale: True + local_rank: 0 + logfile: /workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s42_20260501_022405/v15_bijepaxlite_lqer1_nocompile_s42_20260501_022405.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + lqer_asym_enabled: True + lqer_asym_group: 64 + lqer_enabled: True + lqer_factor_bits: 4 + lqer_gain_select: False + lqer_rank: 4 + lqer_scope: all + lqer_top_k: 1 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.026 + max_wallclock_seconds: 600.0 + min_lr: 0.1 + mlp_clip_sigmas: 11.5 + mlp_mult: 4.0 + model_dim: 512 + model_path: /workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s42_20260501_022405/final_model.pt + muon_backend_steps: 5 + muon_momentum: 0.97 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + ngram_hint_precompute_outside: True + ngram_tilt_enabled: True + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 2 + parallel_final_lane: mean + parallel_start_layer: 8 + phased_ttt_num_phases: 3 + phased_ttt_prefix_docs: 2500 + ppm_dump_inputs: False + ppm_h: 0.999 + ppm_l: 0.18 + ppm_mixer_enabled: True + ppm_order: 5 + ppm_t: 0.8 + qk_gain_init: 5.25 + quantized_model_path: /workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s42_20260501_022405/final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + rope_yarn: False + run_id: v15_bijepaxlite_lqer1_nocompile_s42_20260501_022405 + scalar_lr: 0.02 + seed: 42 + skip_gates_enabled: True + smear_gate_enabled: True + sparse_attn_gate_enabled: True + sparse_attn_gate_init_std: 0.0 + sparse_attn_gate_scale: 0.5 + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + token_boost: 2.625 + token_order: 16 + token_threshold: 0.8 + tokenizer_path: /workspace/parameter-golf/data/tokenizers/fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model + train_batch_tokens: 786432 + train_files: /workspace/parameter-golf/data/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_train_*.bin + train_log_every: 250 + train_seq_len: 2048 + ttt_batch_size: 64 + ttt_beta1: 0.0 + ttt_beta2: 0.99 + ttt_chunk_size: 48 + ttt_enabled: False + ttt_eval_batches: + ttt_eval_seq_len: 2048 + ttt_grad_steps: 1 + ttt_k_lora: True + ttt_local_lr_mult: 0.75 + ttt_lora_lr: 0.0001 + ttt_lora_rank: 80 + ttt_mask: no_qv + ttt_mlp_lora: True + ttt_o_lora: True + ttt_optimizer: adam + ttt_q_lora: False + ttt_train_window_tokens: 0 + ttt_v_lora: False + ttt_weight_decay: 0.5 + val_batch_tokens: 524288 + val_bytes_files: /workspace/parameter-golf/data/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_val_bytes_*.bin + val_doc_fraction: 1.0 + val_files: /workspace/parameter-golf/data/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.85 + warmup_steps: 20 + within_boost: 0.75 + within_tau: 0.45 + word_boost: 0.75 + word_normalize: strip_punct_lower + word_order: 4 + word_tau: 0.65 + world_size: 8 + xsa_last_n: 11 +==================================================================================================== +Source code: +==================================================================================================== +import base64, collections, copy, fcntl, glob, io, lzma, math, os +from pathlib import Path +import random, re, subprocess, sys, time, uuid, numpy as np, sentencepiece as spm, torch, torch.distributed as dist, torch.nn.functional as F +from torch import Tensor, nn +from flash_attn_interface import ( + flash_attn_func as flash_attn_3_func, + flash_attn_varlen_func, +) +from concurrent.futures import ThreadPoolExecutor +import triton +import triton.language as tl +from triton.tools.tensor_descriptor import TensorDescriptor + + +# ===== Fused softcapped cross-entropy (Triton) — training-only path ===== +# Replaces the eager +# logits_softcap = softcap * tanh(logits / softcap) +# F.cross_entropy(logits_softcap.float(), targets, reduction="mean") +# sequence with a single fused kernel that reads logits_proj once, applies +# softcap in-register, and computes (LSE, loss) in one streaming pass. The +# backward kernel mirrors the forward so there's no stored softcapped logits. +# Numerically identical to the eager path up to fp32 accumulation differences. +_FUSED_CE_LIBRARY = "pgsubmission1draft7fusedce" +_FUSED_CE_BLOCK_SIZE = 1024 +_FUSED_CE_NUM_WARPS = 4 + + +@triton.jit +def _softcapped_ce_fwd_kernel( + logits_ptr, losses_ptr, lse_ptr, targets_ptr, + stride_logits_n, stride_logits_v, + n_rows, n_cols, softcap, + block_size: tl.constexpr, +): + row_idx = tl.program_id(0).to(tl.int64) + logits_row_ptr = logits_ptr + row_idx * stride_logits_n + max_val = -float("inf") + sum_exp = 0.0 + A = 2.0 * softcap + inv_C = 2.0 / softcap + for off in range(0, n_cols, block_size): + cols = off + tl.arange(0, block_size) + mask = cols < n_cols + val = tl.load( + logits_row_ptr + cols * stride_logits_v, + mask=mask, other=-float("inf"), + ).to(tl.float32) + z = A * tl.sigmoid(val * inv_C) + z = tl.where(mask, z, -float("inf")) + curr_max = tl.max(z, axis=0) + new_max = tl.maximum(max_val, curr_max) + sum_exp = sum_exp * tl.exp(max_val - new_max) + tl.sum(tl.exp(z - new_max), axis=0) + max_val = new_max + lse = max_val + tl.log(sum_exp) + tl.store(lse_ptr + row_idx, lse) + target = tl.load(targets_ptr + row_idx).to(tl.int32) + target_val = tl.load(logits_row_ptr + target * stride_logits_v).to(tl.float32) + target_z = A * tl.sigmoid(target_val * inv_C) + tl.store(losses_ptr + row_idx, lse - target_z) + + +@triton.jit +def _softcapped_ce_bwd_kernel( + grad_logits_ptr, grad_losses_ptr, lse_ptr, logits_ptr, targets_ptr, + stride_logits_n, stride_logits_v, + stride_grad_n, stride_grad_v, + n_rows, n_cols, softcap, + block_size: tl.constexpr, +): + row_idx = tl.program_id(0).to(tl.int64) + logits_row_ptr = logits_ptr + row_idx * stride_logits_n + grad_row_ptr = grad_logits_ptr + row_idx * stride_grad_n + lse = tl.load(lse_ptr + row_idx) + grad_loss = tl.load(grad_losses_ptr + row_idx).to(tl.float32) + target = tl.load(targets_ptr + row_idx).to(tl.int32) + A = 2.0 * softcap + inv_C = 2.0 / softcap + dz_dx_scale = A * inv_C + for off in range(0, n_cols, block_size): + cols = off + tl.arange(0, block_size) + mask = cols < n_cols + val = tl.load( + logits_row_ptr + cols * stride_logits_v, + mask=mask, other=0.0, + ).to(tl.float32) + sigmoid_u = tl.sigmoid(val * inv_C) + z = A * sigmoid_u + probs = tl.exp(z - lse) + grad_z = grad_loss * (probs - tl.where(cols == target, 1.0, 0.0)) + grad_x = grad_z * (dz_dx_scale * sigmoid_u * (1.0 - sigmoid_u)) + tl.store(grad_row_ptr + cols * stride_grad_v, grad_x, mask=mask) + + +def _validate_softcapped_ce_inputs( + logits: Tensor, targets: Tensor, softcap: float, +) -> tuple[Tensor, Tensor]: + if logits.ndim != 2: + raise ValueError(f"Expected logits.ndim=2, got {logits.ndim}") + if targets.ndim != 1: + raise ValueError(f"Expected targets.ndim=1, got {targets.ndim}") + if logits.shape[0] != targets.shape[0]: + raise ValueError( + f"Expected matching rows, got logits={tuple(logits.shape)} targets={tuple(targets.shape)}" + ) + if not logits.is_cuda or not targets.is_cuda: + raise ValueError("softcapped_cross_entropy requires CUDA tensors") + if softcap <= 0.0: + raise ValueError(f"softcap must be positive, got {softcap}") + if logits.dtype not in (torch.float16, torch.bfloat16, torch.float32): + raise ValueError(f"Unsupported logits dtype: {logits.dtype}") + logits = logits.contiguous() + targets = targets.contiguous() + if targets.dtype != torch.int64: + targets = targets.to(dtype=torch.int64) + return logits, targets + + +@torch.library.custom_op(f"{_FUSED_CE_LIBRARY}::softcapped_ce", mutates_args=()) +def softcapped_ce_op(logits: Tensor, targets: Tensor, softcap: float) -> tuple[Tensor, Tensor]: + logits, targets = _validate_softcapped_ce_inputs(logits, targets, float(softcap)) + n_rows, n_cols = logits.shape + losses = torch.empty((n_rows,), device=logits.device, dtype=torch.float32) + lse = torch.empty((n_rows,), device=logits.device, dtype=torch.float32) + _softcapped_ce_fwd_kernel[(n_rows,)]( + logits, losses, lse, targets, + logits.stride(0), logits.stride(1), + n_rows, n_cols, float(softcap), + block_size=_FUSED_CE_BLOCK_SIZE, num_warps=_FUSED_CE_NUM_WARPS, + ) + return losses, lse + + +@softcapped_ce_op.register_fake +def _(logits: Tensor, targets: Tensor, softcap: float): + if logits.ndim != 2 or targets.ndim != 1: + raise ValueError("softcapped_ce fake impl expects 2D logits and 1D targets") + if logits.shape[0] != targets.shape[0]: + raise ValueError( + f"Expected matching rows, got logits={tuple(logits.shape)} targets={tuple(targets.shape)}" + ) + n_rows = logits.shape[0] + return ( + logits.new_empty((n_rows,), dtype=torch.float32), + logits.new_empty((n_rows,), dtype=torch.float32), + ) + + +@torch.library.custom_op(f"{_FUSED_CE_LIBRARY}::softcapped_ce_backward", mutates_args=()) +def softcapped_ce_backward_op( + logits: Tensor, targets: Tensor, lse: Tensor, grad_losses: Tensor, softcap: float, +) -> Tensor: + logits, targets = _validate_softcapped_ce_inputs(logits, targets, float(softcap)) + lse = lse.contiguous() + grad_losses = grad_losses.contiguous().to(dtype=torch.float32) + if lse.ndim != 1 or grad_losses.ndim != 1: + raise ValueError("Expected 1D lse and grad_losses") + if lse.shape[0] != logits.shape[0] or grad_losses.shape[0] != logits.shape[0]: + raise ValueError( + f"Expected row-aligned lse/grad_losses, got logits={tuple(logits.shape)} " + f"lse={tuple(lse.shape)} grad_losses={tuple(grad_losses.shape)}" + ) + grad_logits = torch.empty_like(logits) + n_rows, n_cols = logits.shape + _softcapped_ce_bwd_kernel[(n_rows,)]( + grad_logits, grad_losses, lse, logits, targets, + logits.stride(0), logits.stride(1), + grad_logits.stride(0), grad_logits.stride(1), + n_rows, n_cols, float(softcap), + block_size=_FUSED_CE_BLOCK_SIZE, num_warps=_FUSED_CE_NUM_WARPS, + ) + return grad_logits + + +@softcapped_ce_backward_op.register_fake +def _(logits: Tensor, targets: Tensor, lse: Tensor, grad_losses: Tensor, softcap: float): + if logits.ndim != 2 or targets.ndim != 1 or lse.ndim != 1 or grad_losses.ndim != 1: + raise ValueError("softcapped_ce_backward fake impl expects 2D logits and 1D row tensors") + if ( + logits.shape[0] != targets.shape[0] + or logits.shape[0] != lse.shape[0] + or logits.shape[0] != grad_losses.shape[0] + ): + raise ValueError("softcapped_ce_backward fake impl expects row-aligned tensors") + return logits.new_empty(logits.shape) + + +def _softcapped_ce_setup_context( + ctx: torch.autograd.function.FunctionCtx, inputs, output, +) -> None: + logits, targets, softcap = inputs + _losses, lse = output + ctx.save_for_backward(logits, targets, lse) + ctx.softcap = float(softcap) + + +def _softcapped_ce_backward( + ctx: torch.autograd.function.FunctionCtx, grad_losses: Tensor, grad_lse: "Tensor | None", +): + del grad_lse + logits, targets, lse = ctx.saved_tensors + grad_logits = torch.ops.pgsubmission1draft7fusedce.softcapped_ce_backward( + logits, targets, lse, grad_losses, ctx.softcap + ) + return grad_logits, None, None + + +softcapped_ce_op.register_autograd( + _softcapped_ce_backward, setup_context=_softcapped_ce_setup_context, +) + + +def softcapped_cross_entropy( + logits: Tensor, targets: Tensor, softcap: float, reduction: str = "mean", +) -> Tensor: + losses, _lse = torch.ops.pgsubmission1draft7fusedce.softcapped_ce( + logits, targets, float(softcap) + ) + if reduction == "none": + return losses + if reduction == "sum": + return losses.sum() + if reduction == "mean": + return losses.mean() + raise ValueError(f"Unsupported reduction={reduction!r}") + + +class _BiJEPAXHead(nn.Module): + def __init__(self, d_model: int, head_dim: int): + super().__init__() + self.w1 = nn.Linear(d_model, head_dim, bias=False) + self.ln = nn.LayerNorm(head_dim) + self.w2 = nn.Linear(head_dim, d_model, bias=False) + nn.init.orthogonal_(self.w1.weight) + nn.init.zeros_(self.w2.weight) + + def forward(self, x: Tensor) -> Tensor: + return self.w2(F.gelu(self.ln(self.w1(x.float())))) + + +class MultiDirectionalBiJEPAX(nn.Module): + """Training-only multi-hop bidirectional/cycle JEPA regularizer.""" + + def __init__( + self, + d_model: int, + fwd_hops: tuple[int, ...] = (1, 4, 16), + bwd_hops: tuple[int, ...] = (1, 4), + cycle: bool = True, + head_dim: int = 64, + ): + super().__init__() + self.fwd_hops = list(fwd_hops) + self.bwd_hops = list(bwd_hops) + self.cycle = cycle + self.fwd_heads = nn.ModuleList( + [_BiJEPAXHead(d_model, head_dim) for _ in fwd_hops] + ) + self.bwd_heads = nn.ModuleList( + [_BiJEPAXHead(d_model, head_dim) for _ in bwd_hops] + ) + self.cyc_heads = nn.ModuleList( + [_BiJEPAXHead(d_model, head_dim) for _ in fwd_hops] + ) if cycle else nn.ModuleList() + + @staticmethod + def _cos_loss(pred: Tensor, tgt: Tensor) -> Tensor: + pred = F.normalize(pred, dim=-1) + tgt = F.normalize(tgt.float().detach(), dim=-1) + return (1.0 - (pred * tgt).sum(dim=-1)).mean() + + def forward(self, hidden: Tensor, weight: float, stride: int) -> Tensor: + if weight <= 0.0 or hidden.size(1) < 2: + return hidden.new_zeros(()) + stride = max(1, int(stride)) + seq_len = hidden.size(1) + losses: list[Tensor] = [] + for i, hop in enumerate(self.fwd_hops): + if seq_len <= hop: + continue + src = hidden[:, : seq_len - hop : stride, :] + tgt = hidden[:, hop : hop + src.size(1) * stride : stride, :] + n = min(src.size(1), tgt.size(1)) + if n <= 0: + continue + src = src[:, :n, :] + tgt = tgt[:, :n, :] + losses.append(self._cos_loss(self.fwd_heads[i](src), tgt)) + if self.cycle and i < len(self.cyc_heads): + losses.append(self._cos_loss(self.cyc_heads[i](tgt.detach()), src)) + for i, hop in enumerate(self.bwd_hops): + if seq_len <= hop: + continue + src = hidden[:, hop::stride, :] + tgt = hidden[:, : src.size(1) * stride : stride, :] + n = min(src.size(1), tgt.size(1)) + if n <= 0: + continue + losses.append(self._cos_loss(self.bwd_heads[i](src[:, :n, :]), tgt[:, :n, :])) + if not losses: + return hidden.new_zeros(()) + return weight * (sum(losses) / len(losses)) + + +def bijepax_weight_at(frac: float, start: float, end: float, peak: float) -> float: + if frac < start or frac >= end: + return 0.0 + span = max(end - start, 1e-9) + ramp = min(span * 0.15, 0.08) + if frac < start + ramp: + return peak * (frac - start) / max(ramp, 1e-9) + if frac >= end - ramp: + return peak * (end - frac) / max(ramp, 1e-9) + return peak + + +class Hyperparameters: + data_dir = os.environ.get("DATA_DIR", "./data/") + seed = int(os.environ.get("SEED", 1337)) + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_frac = float(os.environ.get("WARMDOWN_FRAC", 0.85)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786432)) + # Fused softcapped CE (Triton). Training-only — forward_logits eval path still uses + # eager softcap+F.cross_entropy. Default ON since validated as at-worst neutral. + fused_ce_enabled = bool(int(os.environ.get("FUSED_CE_ENABLED", "1"))) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 6e2)) + val_batch_tokens = int(os.environ.get("VAL_BATCH_TOKENS", 524288)) + # v13 is the sidecar-aware PPM lane. These defaults match the under-cap + # H100 package runs instead of the older TTT-first v12 defaults. + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 8192)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 11)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 4.0)) + skip_gates_enabled = bool(int(os.environ.get("SKIP_GATES_ENABLED", "1"))) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 3e1)) + rope_base = float(os.environ.get("ROPE_BASE", 1e4)) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + rope_train_seq_len = int(os.environ.get("ROPE_TRAIN_SEQ_LEN", 2048)) + rope_yarn = bool(int(os.environ.get("ROPE_YARN", "0"))) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 5.25)) + num_loops = int(os.environ.get("NUM_LOOPS", 2)) + loop_start = int(os.environ.get("LOOP_START", 3)) + loop_end = int(os.environ.get("LOOP_END", 5)) + enable_looping_at = float(os.environ.get("ENABLE_LOOPING_AT", 0.35)) + parallel_start_layer = int(os.environ.get("PARALLEL_START_LAYER", 8)) + parallel_final_lane = os.environ.get("PARALLEL_FINAL_LANE", "mean") + min_lr = float(os.environ.get("MIN_LR", 0.1)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.026)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.97)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float( + os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92) + ) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + muon_row_normalize = bool(int(os.environ.get("MUON_ROW_NORMALIZE", "1"))) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.99)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-08)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 512)) + adam_wd = float(os.environ.get("ADAM_WD", 0.02)) + muon_wd = float(os.environ.get("MUON_WD", 0.095)) + embed_wd = float(os.environ.get("EMBED_WD", 0.085)) + ema_decay = float(os.environ.get("EMA_DECAY", 0.9965)) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 80)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.0001)) + ttt_local_lr_mult = float(os.environ.get("TTT_LOCAL_LR_MULT", 0.75)) + ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 48)) + ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 2048)) + ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64)) + ttt_grad_steps = int(os.environ.get("TTT_GRAD_STEPS", 1)) + ttt_train_window_tokens = int(os.environ.get("TTT_TRAIN_WINDOW_TOKENS", 0)) + # V19: PR #1886 (renqianluo) + sunnypatneedi research log 2026-04-28 found that + # the Triton fused-CE kernel's fp32-accumulation interacts with warm-start LoRA-A + # to destabilize seeds 314/1337 at TTT_WEIGHT_DECAY=1.0. Raising the default to + # 2.0 prevents seed collapse without measurably moving stable seeds. + ttt_weight_decay = float(os.environ.get("TTT_WEIGHT_DECAY", 0.5)) + ttt_beta1 = float(os.environ.get("TTT_BETA1", 0)) + ttt_beta2 = float(os.environ.get("TTT_BETA2", 0.99)) + ttt_mask = os.environ.get("TTT_MASK", "no_qv").strip().lower() + _ttt_q_default = "1" + _ttt_v_default = "1" + if ttt_mask in ("", "all", "baseline_all"): + pass + elif ttt_mask == "no_q": + _ttt_q_default = "0" + elif ttt_mask == "no_v": + _ttt_v_default = "0" + elif ttt_mask == "no_qv": + _ttt_q_default = "0" + _ttt_v_default = "0" + else: + raise ValueError(f"Unsupported TTT_MASK={ttt_mask!r}") + ttt_q_lora = bool(int(os.environ.get("TTT_Q_LORA", _ttt_q_default))) + ttt_k_lora = bool(int(os.environ.get("TTT_K_LORA", "1"))) + ttt_v_lora = bool(int(os.environ.get("TTT_V_LORA", _ttt_v_default))) + ttt_mlp_lora = bool(int(os.environ.get("TTT_MLP_LORA", "1"))) + ttt_o_lora = bool(int(os.environ.get("TTT_O_LORA", "1"))) + ttt_optimizer = os.environ.get("TTT_OPTIMIZER", "adam") + ttt_eval_batches = os.environ.get("TTT_EVAL_BATCHES", "") + val_doc_fraction = float(os.environ.get("VAL_DOC_FRACTION", 1.0)) + compressor = os.environ.get("COMPRESSOR", "pergroup") + gptq_calibration_batches = int(os.environ.get("GPTQ_CALIBRATION_BATCHES", 16)) + gptq_reserve_seconds = float(os.environ.get("GPTQ_RESERVE_SECONDS", 0.5)) + phased_ttt_prefix_docs = int(os.environ.get("PHASED_TTT_PREFIX_DOCS", 2500)) + phased_ttt_num_phases = int(os.environ.get("PHASED_TTT_NUM_PHASES", 3)) + global_ttt_lr = float(os.environ.get("GLOBAL_TTT_LR", 0.001)) + global_ttt_momentum = float(os.environ.get("GLOBAL_TTT_MOMENTUM", 0.9)) + global_ttt_epochs = int(os.environ.get("GLOBAL_TTT_EPOCHS", 1)) + global_ttt_chunk_tokens = int(os.environ.get("GLOBAL_TTT_CHUNK_TOKENS", 32768)) + global_ttt_batch_seqs = int(os.environ.get("GLOBAL_TTT_BATCH_SEQS", 32)) + global_ttt_warmup_start_lr = float(os.environ.get("GLOBAL_TTT_WARMUP_START_LR", 0.0)) + global_ttt_warmup_chunks = int(os.environ.get("GLOBAL_TTT_WARMUP_CHUNKS", 0)) + global_ttt_grad_clip = float(os.environ.get("GLOBAL_TTT_GRAD_CLIP", 1.0)) + global_ttt_respect_doc_boundaries = bool(int(os.environ.get("GLOBAL_TTT_RESPECT_DOC_BOUNDARIES", "1"))) + matrix_bits = int(os.environ.get("MATRIX_BITS", 6)) + embed_bits = int(os.environ.get("EMBED_BITS", 7)) + matrix_clip_sigmas = float(os.environ.get("MATRIX_CLIP_SIGMAS", 12.85)) + embed_clip_sigmas = float(os.environ.get("EMBED_CLIP_SIGMAS", 14.0)) + mlp_clip_sigmas = float(os.environ.get("MLP_CLIP_SIGMAS", 11.5)) + attn_clip_sigmas = float(os.environ.get("ATTN_CLIP_SIGMAS", 13.0)) + # AttnOutGate (per-head multiplicative output gate, PR #1667 MarioPaerle). + # Zero-init weight: 2*sigmoid(0)=1 -> transparent at start. Source defaults to + # block input x ('proj'); 'q' uses raw Q projection output. + attn_out_gate_enabled = bool(int(os.environ.get("ATTN_OUT_GATE_ENABLED", "0"))) + attn_out_gate_src = os.environ.get("ATTN_OUT_GATE_SRC", "proj") + # SmearGate (input-dependent forward-1 token smear, modded-nanogpt @classiclarryd + # via PR #1667). x_t <- x_t + lam * sigmoid(W*x_t[:gate_window]) * x_{t-1}. + # lam=0 + W=0 -> transparent at init. + smear_gate_enabled = bool(int(os.environ.get("SMEAR_GATE_ENABLED", "1"))) + # Window: first GATE_WINDOW dims of the source feed the gate projection. + gate_window = int(os.environ.get("GATE_WINDOW", 12)) + # Gated Attention (Qwen, NeurIPS 2025 Best Paper, arXiv:2505.06708; + # qiuzh20/gated_attention). Per-head sigmoid gate on SDPA output, BEFORE + # out_proj. Gate input = full block input x (paper's headwise G1 variant + # driven from hidden_states). W_g shape (num_heads, dim), plain sigmoid. + # Near-zero init gives g~0.5 at step 0 (half attention output); per-block + # attn_scale (init 1.0) compensates during training. Name contains + # "attn_gate" so CONTROL_TENSOR_NAME_PATTERNS routes it to scalar AdamW. + gated_attn_enabled = bool(int(os.environ.get("GATED_ATTN_ENABLED", "0"))) + gated_attn_init_std = float(os.environ.get("GATED_ATTN_INIT_STD", 0.01)) + # Dedicated int8-per-row quantization for `attn_gate_w` tensors. These are + # small ((num_heads, dim) = (8, 512) = 4096 params) and bypass GPTQ via the + # numel<=65536 passthrough branch -> stored as fp16 (8 KB/layer, ~65 KB total + # compressed). int8-per-row cuts the raw tensor in half with negligible BPB + # impact: scales per head (8 values), symmetric quant over [-127, 127]. + # No Hessian needed (gate weights not in collect_hessians()). + gated_attn_quant_gate = bool(int(os.environ.get("GATED_ATTN_QUANT_GATE", "1"))) + # Sparse Attention Gate (modded-nanogpt-style). Keeps dense SDPA and only + # swaps the output-gate input to the first GATE_WINDOW residual dims. + # W_g: (num_heads, gate_window) = (8, 12) = 96 params/layer (~44K total), + # vs dense GatedAttn's (8, 512) = 4K/layer (~44K diff). Name "attn_gate_w" + # is shared so quant routing and int8 gate passthrough Just Work. Gate + # passthrough int8 still applies via GATED_ATTN_QUANT_GATE=1. + # Mutually exclusive with ATTN_OUT_GATE_ENABLED and GATED_ATTN_ENABLED. + sparse_attn_gate_enabled = bool(int(os.environ.get("SPARSE_ATTN_GATE_ENABLED", "1"))) + sparse_attn_gate_init_std = float(os.environ.get("SPARSE_ATTN_GATE_INIT_STD", 0.0)) + sparse_attn_gate_scale = float(os.environ.get("SPARSE_ATTN_GATE_SCALE", 0.5)) + # LQER asymmetric rank-k correction on top-K quant-error tensors (PR #1530 v2 port). + # Computes SVD of E = W_fp - W_quant, packs top-r A,B as INT2/INT4 (asym) or INTk (sym). + lqer_enabled = bool(int(os.environ.get("LQER_ENABLED", "1"))) + lqer_rank = int(os.environ.get("LQER_RANK", 4)) + lqer_top_k = int(os.environ.get("LQER_TOP_K", 3)) + lqer_factor_bits = int(os.environ.get("LQER_FACTOR_BITS", 4)) + lqer_asym_enabled = bool(int(os.environ.get("LQER_ASYM_ENABLED", "1"))) + lqer_asym_group = int(os.environ.get("LQER_ASYM_GROUP", "64")) + lqer_scope = os.environ.get("LQER_SCOPE", "all") + lqer_gain_select = bool(int(os.environ.get("LQER_GAIN_SELECT", "0"))) + awq_lite_enabled = bool(int(os.environ.get("AWQ_LITE_ENABLED", "1"))) + awq_lite_bits = int(os.environ.get("AWQ_LITE_BITS", "8")) + awq_lite_group_top_k = int(os.environ.get("AWQ_LITE_GROUP_TOP_K", "1")) + awq_lite_group_size = int(os.environ.get("AWQ_LITE_GROUP_SIZE", "64")) + # PR #1145/#1967 online n-gram tilt. This is a causal scoring overlay: + # prefix-only token/within-word/word experts propose one hint token, then + # the per-token NLL is adjusted with closed-form softmax renormalization. + ngram_tilt_enabled = bool(int(os.environ.get("NGRAM_TILT_ENABLED", "1"))) + token_order = int(os.environ.get("TOKEN_ORDER", "16")) + token_threshold = float(os.environ.get("TOKEN_THRESHOLD", "0.800")) + token_boost = float(os.environ.get("TOKEN_BOOST", "2.625")) + within_tau = float(os.environ.get("WITHIN_TAU", "0.450")) + within_boost = float(os.environ.get("WITHIN_BOOST", "0.750")) + word_order = int(os.environ.get("WORD_ORDER", "4")) + word_normalize = os.environ.get("WORD_NORMALIZE", "strip_punct_lower") + word_tau = float(os.environ.get("WORD_TAU", "0.650")) + word_boost = float(os.environ.get("WORD_BOOST", "0.750")) + agree_add_boost = float(os.environ.get("AGREE_ADD_BOOST", "0.500")) + ngram_hint_precompute_outside = bool(int(os.environ.get("NGRAM_HINT_PRECOMPUTE_OUTSIDE", "1"))) + ppm_mixer_enabled = bool(int(os.environ.get("PPM_MIXER_ENABLED", "1"))) + ppm_order = int(os.environ.get("PPM_ORDER", "5")) + ppm_h = float(os.environ.get("PPM_H", "0.995")) + ppm_l = float(os.environ.get("PPM_L", "0.20")) + ppm_t = float(os.environ.get("PPM_T", "0.80")) + ppm_dump_inputs = bool(int(os.environ.get("PPM_DUMP_INPUTS", "0"))) + # JEPA-Lite training regularizer from PR #2027. The predictor is never + # serialized; it only nudges hidden states during training. + jepa_lite_enabled = bool(int(os.environ.get("JEPA_LITE_ENABLED", "0"))) + jepa_lite_weight = float(os.environ.get("JEPA_LITE_WEIGHT", "0.03")) + jepa_lite_hidden = int(os.environ.get("JEPA_LITE_HIDDEN", "128")) + jepa_lite_horizon = int(os.environ.get("JEPA_LITE_HORIZON", "4")) + jepa_lite_stride = int(os.environ.get("JEPA_LITE_STRIDE", "16")) + jepa_lite_frequency = int(os.environ.get("JEPA_LITE_FREQUENCY", "8")) + jepa_lite_start_frac = float(os.environ.get("JEPA_LITE_START_FRAC", "0.35")) + jepa_lite_end_frac = float(os.environ.get("JEPA_LITE_END_FRAC", "0.85")) + jepa_lite_lr = float(os.environ.get("JEPA_LITE_LR", "0.01")) + bijepa_enabled = bool(int(os.environ.get("BIJEPA_ENABLED", "0"))) + bijepa_backward_weight = float(os.environ.get("BIJEPA_BACKWARD_WEIGHT", "1.0")) + bijepax_enabled = bool(int(os.environ.get("BIJEPAX_ENABLED", "0"))) + bijepax_weight = float(os.environ.get("BIJEPAX_WEIGHT", "0.04")) + bijepax_start_frac = float(os.environ.get("BIJEPAX_START_FRAC", "0.05")) + bijepax_end_frac = float(os.environ.get("BIJEPAX_END_FRAC", "0.80")) + bijepax_fwd_hops = tuple( + int(x) for x in os.environ.get("BIJEPAX_FWD_HOPS", "1,4,16").split(",") if x + ) + bijepax_bwd_hops = tuple( + int(x) for x in os.environ.get("BIJEPAX_BWD_HOPS", "1,4").split(",") if x + ) + bijepax_cycle = bool(int(os.environ.get("BIJEPAX_CYCLE", "1"))) + bijepax_head_dim = int(os.environ.get("BIJEPAX_HEAD_DIM", "64")) + bijepax_stride = int(os.environ.get("BIJEPAX_STRIDE", "8")) + bijepax_lr = float(os.environ.get("BIJEPAX_LR", "3e-4")) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + is_main_process = rank == 0 + grad_accum_steps = 8 // world_size + # CaseOps integration: optional override of dataset root + tokenizer path. + # When CASEOPS_ENABLED=1, the wrapper loads a per-token byte sidecar + # (fineweb_val_bytes_*.bin, identical shard layout to val_*.bin) and uses + # it as the canonical raw-byte budget for BPB accounting. The sidecar + # REPLACES the build_sentencepiece_luts byte-counting path entirely. + caseops_enabled = bool(int(os.environ.get("CASEOPS_ENABLED", "1"))) + _default_caseops_data = os.path.join( + data_dir, + "datasets", + "fineweb10B_sp8192_caseops", + "datasets", + "datasets", + "fineweb10B_sp8192_lossless_caps_caseops_v1_reserved", + ) + _default_caseops_tok = os.path.join( + data_dir, + "datasets", + "fineweb10B_sp8192_caseops", + "datasets", + "tokenizers", + "fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model", + ) + if caseops_enabled: + datasets_dir = os.environ.get("DATA_PATH", _default_caseops_data) + tokenizer_path = os.environ.get("TOKENIZER_PATH", _default_caseops_tok) + else: + datasets_dir = os.environ.get( + "DATA_PATH", + os.path.join(data_dir, "datasets", f"fineweb10B_sp{vocab_size}"), + ) + tokenizer_path = os.environ.get( + "TOKENIZER_PATH", + os.path.join(data_dir, "tokenizers", f"fineweb_{vocab_size}_bpe.model"), + ) + train_files = os.path.join(datasets_dir, "fineweb_train_*.bin") + val_files = os.path.join(datasets_dir, "fineweb_val_*.bin") + val_bytes_files = os.path.join(datasets_dir, "fineweb_val_bytes_*.bin") + artifact_dir = os.environ.get("ARTIFACT_DIR", "") + logfile = ( + os.path.join(artifact_dir, f"{run_id}.txt") + if artifact_dir + else f"logs/{run_id}.txt" + ) + model_path = ( + os.path.join(artifact_dir, "final_model.pt") + if artifact_dir + else "final_model.pt" + ) + quantized_model_path = ( + os.path.join(artifact_dir, "final_model.int6.ptz") + if artifact_dir + else "final_model.int6.ptz" + ) + + +_logger_hparams = None + + +def set_logging_hparams(h): + global _logger_hparams + _logger_hparams = h + + +def log(msg, console=True): + if _logger_hparams is None: + print(msg) + return + if _logger_hparams.is_main_process: + if console: + print(msg) + if _logger_hparams.logfile is not None: + with open(_logger_hparams.logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + +class ValidationData: + def __init__(self, h, device): + self.sp = spm.SentencePieceProcessor(model_file=h.tokenizer_path) + if int(self.sp.vocab_size()) != h.vocab_size: + raise ValueError( + f"VOCAB_SIZE={h.vocab_size} does not match tokenizer vocab_size={int(self.sp.vocab_size())}" + ) + self.val_tokens = load_validation_tokens(h.val_files, h.eval_seq_len) + self.caseops_enabled = bool(getattr(h, "caseops_enabled", False)) + if self.caseops_enabled: + self.base_bytes_lut = None + self.has_leading_space_lut = None + self.is_boundary_token_lut = None + else: + ( + self.base_bytes_lut, + self.has_leading_space_lut, + self.is_boundary_token_lut, + ) = build_sentencepiece_luts(self.sp, h.vocab_size, device) + self.val_bytes = None + if self.caseops_enabled: + self.val_bytes = load_validation_byte_sidecar( + h.val_bytes_files, h.eval_seq_len, self.val_tokens.numel() + ) + + +def build_sentencepiece_luts(sp, vocab_size, device): + sp_vocab_size = int(sp.vocab_size()) + assert ( + sp.piece_to_id("▁") != sp.unk_id() + ), "Tokenizer must have '▁' (space) as its own token for correct BPB byte counting" + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern, seq_len): + # Filter out CaseOps byte sidecar shards which share the val_*.bin glob. + files = [ + Path(p) + for p in sorted(glob.glob(pattern)) + if "_bytes_" not in Path(p).name + ] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = (tokens.numel() - 1) // seq_len * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def load_validation_byte_sidecar(pattern, seq_len, expected_len): + """Load CaseOps per-token byte sidecar(s). Same shard layout as token shards + (256 int32 header + uint16 array). Each entry = canonical raw-text byte + budget for that token in the corresponding val shard. Returns a CPU + int16 tensor sliced to match expected_len (i.e. val_tokens length).""" + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No byte sidecar files for pattern: {pattern}") + shards = [load_data_shard(file) for file in files] + # load_data_shard returns uint16 — that's exactly what the sidecar stores. + bytes_full = torch.cat(shards).contiguous() + if bytes_full.numel() < expected_len: + raise ValueError( + f"Byte sidecar too short: {bytes_full.numel()} < val_tokens {expected_len}" + ) + return bytes_full[:expected_len].to(torch.int32) + + +def load_data_shard(file): + header_bytes = 256 * np.dtype(" 0: + pos = start + while pos < end: + seg_starts.append(pos) + pos += max_doc_len + else: + seg_starts.append(start) + boundaries = seg_starts + [total_len] + padded_len = get_next_multiple_of_n(len(boundaries), bucket_size) + cu = torch.full((padded_len,), total_len, dtype=torch.int32, device=device) + cu[: len(boundaries)] = torch.tensor(boundaries, dtype=torch.int32, device=device) + seg_ends = seg_starts[1:] + [total_len] + max_seqlen = max(end - start for start, end in zip(seg_starts, seg_ends)) + return cu, max_seqlen + +class DocumentPackingLoader: + _shard_pool = ThreadPoolExecutor(1) + + def __init__(self, h, device, cu_bucket_size=64): + self.rank = h.rank + self.world_size = h.world_size + self.device = device + self.cu_bucket_size = cu_bucket_size + self.max_seq_len = h.train_seq_len + all_files = [Path(p) for p in sorted(glob.glob(h.train_files))] + if not all_files: + raise FileNotFoundError(f"No files found for pattern: {h.train_files}") + self.files = all_files + self.file_iter = iter(self.files) + self._init_shard(load_data_shard(next(self.file_iter))) + self._next_shard = self._submit_next_shard() + self._batch_pool = ThreadPoolExecutor(1) + self._prefetch_queue = [] + + def _init_shard(self, tokens): + global BOS_ID + self.tokens = tokens + self.shard_size = tokens.numel() + if BOS_ID is None: + BOS_ID = 1 + self.bos_idx = ( + (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + ) + self.cursor = int(self.bos_idx[0]) + + def _submit_next_shard(self): + try: + path = next(self.file_iter) + return self._shard_pool.submit(load_data_shard, path) + except StopIteration: + return None + + def _advance_shard(self): + if self._next_shard is None: + self.file_iter = iter(self.files) + self._next_shard = self._shard_pool.submit( + load_data_shard, next(self.file_iter) + ) + self._init_shard(self._next_shard.result()) + self._next_shard = self._submit_next_shard() + + def _local_doc_starts(self, local_start, total_len): + lo = np.searchsorted(self.bos_idx, local_start, side="left") + hi = np.searchsorted(self.bos_idx, local_start + total_len, side="left") + return (self.bos_idx[lo:hi] - local_start).tolist() + + def _prepare_batch(self, num_tokens_local, max_seq_len): + per_rank_span = num_tokens_local + 1 + global_span = per_rank_span * self.world_size + while self.cursor + global_span > self.shard_size: + self._advance_shard() + local_start = self.cursor + self.rank * per_rank_span + buf = self.tokens[local_start : local_start + per_rank_span] + inputs = torch.empty(per_rank_span - 1, dtype=torch.int64, pin_memory=True) + targets = torch.empty(per_rank_span - 1, dtype=torch.int64, pin_memory=True) + inputs.copy_(buf[:-1]) + targets.copy_(buf[1:]) + starts = self._local_doc_starts(local_start, inputs.numel()) + cu_seqlens, max_seqlen = _build_cu_seqlens( + starts, inputs.numel(), inputs.device, max_seq_len, self.cu_bucket_size + ) + cu_seqlens = cu_seqlens.pin_memory() + self.cursor += global_span + return inputs, targets, cu_seqlens, max_seqlen + + def next_batch(self, global_tokens, grad_accum_steps): + num_tokens_local = global_tokens // (self.world_size * grad_accum_steps) + while len(self._prefetch_queue) < 2: + self._prefetch_queue.append( + self._batch_pool.submit(self._prepare_batch, num_tokens_local, self.max_seq_len)) + inputs, targets, cu_seqlens, max_seqlen = self._prefetch_queue.pop(0).result() + self._prefetch_queue.append( + self._batch_pool.submit(self._prepare_batch, num_tokens_local, self.max_seq_len)) + return ( + inputs[None].to(self.device, non_blocking=True), + targets[None].to(self.device, non_blocking=True), + cu_seqlens.to(self.device, non_blocking=True), + max_seqlen, + ) + + +class ShuffledSequenceLoader: + def __init__(self, h, device): + self.world_size = h.world_size + self.seq_len = h.train_seq_len + self.device = device + all_files = [Path(p) for p in sorted(glob.glob(h.train_files))] + if not all_files: + raise FileNotFoundError(f"No files found for pattern: {h.train_files}") + self.files = all_files[h.rank :: h.world_size] + self.rng = np.random.Generator(np.random.PCG64(h.rank)) + self.num_tokens = [_read_num_tokens(f) for f in self.files] + self.start_inds = [[] for _ in self.files] + for si in range(len(self.files)): + self._reset_shard(si) + + def _reset_shard(self, si): + max_phase = min( + self.seq_len - 1, max(0, self.num_tokens[si] - self.seq_len - 1) + ) + phase = int(self.rng.integers(max_phase + 1)) if max_phase > 0 else 0 + num_sequences = (self.num_tokens[si] - 1 - phase) // self.seq_len + sequence_order = self.rng.permutation(num_sequences) + self.start_inds[si] = (phase + sequence_order * self.seq_len).tolist() + + def next_batch(self, global_tokens, grad_accum_steps): + device_tokens = global_tokens // (self.world_size * grad_accum_steps) + device_batch_size = device_tokens // self.seq_len + remaining = np.array([len(s) for s in self.start_inds], dtype=np.float64) + x = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64) + y = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64) + for bi in range(device_batch_size): + total = remaining.sum() + if total <= 0: + for si in range(len(self.files)): + self._reset_shard(si) + remaining = np.array( + [len(s) for s in self.start_inds], dtype=np.float64 + ) + total = remaining.sum() + probs = remaining / total + si = int(self.rng.choice(len(self.files), p=probs)) + start_ind = self.start_inds[si].pop() + remaining[si] -= 1 + mm = _get_shard_memmap(self.files[si]) + window = torch.as_tensor( + np.array(mm[start_ind : start_ind + self.seq_len + 1], dtype=np.int64) + ) + x[bi] = window[:-1] + y[bi] = window[1:] + return x.to(self.device, non_blocking=True), y.to( + self.device, non_blocking=True + ) + + +class RMSNorm(nn.Module): + def __init__(self, eps=None): + super().__init__() + self.eps = eps + + def forward(self, x): + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x): + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +@triton.jit +def fused_log_softmax_dual_gather_kernel( + logits_ptr, + target_ids_ptr, + hint_ids_ptr, + log_p_y_out_ptr, + log_q_h_out_ptr, + n_rows, + n_cols, + block_cols: tl.constexpr, +): + row_idx = tl.program_id(0) + if row_idx >= n_rows: + return + target = tl.load(target_ids_ptr + row_idx) + hint = tl.load(hint_ids_ptr + row_idx) + row_offset = row_idx * n_cols + target_logit = tl.load(logits_ptr + row_offset + target).to(tl.float32) + hint_logit = tl.load(logits_ptr + row_offset + hint).to(tl.float32) + max_val = -float("inf") + for col_start in tl.range(0, n_cols, block_cols): + cols = col_start + tl.arange(0, block_cols) + mask = cols < n_cols + vals = tl.load( + logits_ptr + row_offset + cols, mask=mask, other=-float("inf") + ).to(tl.float32) + max_val = tl.maximum(max_val, tl.max(vals, axis=0)) + sum_exp = tl.zeros((), dtype=tl.float32) + for col_start in tl.range(0, n_cols, block_cols): + cols = col_start + tl.arange(0, block_cols) + mask = cols < n_cols + vals = tl.load( + logits_ptr + row_offset + cols, mask=mask, other=0.0 + ).to(tl.float32) + sum_exp += tl.sum(tl.where(mask, tl.exp(vals - max_val), 0.0), axis=0) + lse = max_val + tl.log(sum_exp) + tl.store(log_p_y_out_ptr + row_idx, target_logit - lse) + tl.store(log_q_h_out_ptr + row_idx, hint_logit - lse) + + +def fused_log_softmax_dual_gather(logits, target_ids, hint_ids): + bsz, seqlen, vocab = logits.shape + n_rows = bsz * seqlen + logits_flat = logits.reshape(n_rows, vocab).contiguous() + target_flat = target_ids.reshape(n_rows).contiguous() + hint_flat = hint_ids.reshape(n_rows).contiguous() + log_p_y_out = torch.empty(n_rows, dtype=torch.float32, device=logits.device) + log_q_h_out = torch.empty(n_rows, dtype=torch.float32, device=logits.device) + fused_log_softmax_dual_gather_kernel[(n_rows,)]( + logits_flat, + target_flat, + hint_flat, + log_p_y_out, + log_q_h_out, + n_rows, + vocab, + block_cols=1024, + num_warps=8, + ) + return log_p_y_out.reshape(bsz, seqlen), log_q_h_out.reshape(bsz, seqlen) + + +@triton.jit +def linear_leaky_relu_square_kernel( + a_desc, + b_desc, + c_desc, + aux_desc, + M, + N, + K, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + NUM_SMS: tl.constexpr, + FORWARD: tl.constexpr, +): + dtype = tl.bfloat16 + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + tile_id_c = start_pid - NUM_SMS + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for ki in range(k_tiles): + offs_k = ki * BLOCK_SIZE_K + a = a_desc.load([offs_am, offs_k]) + b = b_desc.load([offs_bn, offs_k]) + accumulator = tl.dot(a, b.T, accumulator) + tile_id_c += NUM_SMS + offs_am_c = offs_am + offs_bn_c = offs_bn + acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) + acc = tl.permute(acc, (0, 2, 1)) + acc0, acc1 = tl.split(acc) + c0 = acc0.to(dtype) + c1 = acc1.to(dtype) + if not FORWARD: + pre0 = aux_desc.load([offs_am_c, offs_bn_c]) + pre1 = aux_desc.load([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2]) + c0 = c0 * tl.where(pre0 > 0, 2.0 * pre0, 0.3 * pre0) + c1 = c1 * tl.where(pre1 > 0, 2.0 * pre1, 0.3 * pre1) + c_desc.store([offs_am_c, offs_bn_c], c0) + c_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], c1) + if FORWARD: + aux0 = tl.where(c0 > 0, c0, 0.3 * c0) + aux1 = tl.where(c1 > 0, c1, 0.3 * c1) + aux_desc.store([offs_am_c, offs_bn_c], aux0 * aux0) + aux_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], aux1 * aux1) + + +def linear_leaky_relu_square(a, b, aux=None): + M, K = a.shape + N, K2 = b.shape + assert K == K2 + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + forward = aux is None + if aux is None: + aux = torch.empty((M, N), device=a.device, dtype=a.dtype) + num_sms = torch.cuda.get_device_properties(a.device).multi_processor_count + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 256, 128, 64 + num_stages = 4 if forward else 3 + a_desc = TensorDescriptor.from_tensor(a, [BLOCK_SIZE_M, BLOCK_SIZE_K]) + b_desc = TensorDescriptor.from_tensor(b, [BLOCK_SIZE_N, BLOCK_SIZE_K]) + c_desc = TensorDescriptor.from_tensor(c, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + aux_desc = TensorDescriptor.from_tensor(aux, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + grid = lambda _meta: ( + min(num_sms, triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N)), + ) + linear_leaky_relu_square_kernel[grid]( + a_desc, + b_desc, + c_desc, + aux_desc, + M, + N, + K, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + NUM_SMS=num_sms, + FORWARD=forward, + num_stages=num_stages, + num_warps=8, + ) + if forward: + return c, aux + return c + + +class FusedLinearLeakyReLUSquareFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, w1, w2): + x_flat = x.reshape(-1, x.shape[-1]) + pre, post = linear_leaky_relu_square(x_flat, w1) + out = F.linear(post, w2) + ctx.save_for_backward(x, w1, w2, pre, post) + return out.view(*x.shape[:-1], out.shape[-1]) + + @staticmethod + def backward(ctx, grad_output): + x, w1, w2, pre, post = ctx.saved_tensors + x_flat = x.reshape(-1, x.shape[-1]) + grad_output_flat = grad_output.reshape(-1, grad_output.shape[-1]) + dw2 = grad_output_flat.T @ post + dpre = linear_leaky_relu_square(grad_output_flat, w2.T.contiguous(), aux=pre) + dw1 = dpre.T @ x_flat + dx = dpre @ w1 + return dx.view_as(x), dw1, dw2 + + +FusedLeakyReLUSquareMLP = FusedLinearLeakyReLUSquareFunction.apply + + +class Rotary(nn.Module): + def __init__(self, dim, base=1e4, train_seq_len=1024, rope_dims=0, yarn=True): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.yarn = yarn + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / base ** ( + torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + + def forward(self, seq_len, device, dtype): + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached < seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if self.yarn and seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * scale ** (rd / (rd - 2)) + inv_freq = 1.0 / new_base ** ( + torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd + ) + else: + inv_freq = self.inv_freq.float().to(device) + t = torch.arange(seq_len, device=device, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached[:, :seq_len].to(dtype=dtype), self._sin_cached[:, :seq_len].to(dtype=dtype) + + +def apply_rotary_emb(x, cos, sin, rope_dims=0): + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * -sin + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * -sin + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len, yarn=True, + attn_out_gate=False, attn_out_gate_src="proj", gate_window=12, + gated_attn=False, gated_attn_init_std=0.01, + sparse_attn_gate=False, sparse_attn_gate_init_std=0.0, sparse_attn_gate_scale=1.0, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + if int(attn_out_gate) + int(gated_attn) + int(sparse_attn_gate) > 1: + raise ValueError( + "attn_out_gate, gated_attn, and sparse_attn_gate are mutually exclusive" + ) + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + self.q_gain = nn.Parameter( + torch.full((num_heads,), qk_gain_init, dtype=torch.float32) + ) + self.rope_dims = 0 + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=train_seq_len, yarn=yarn) + self.use_xsa = False + # AttnOutGate (PR #1667 MarioPaerle): per-head multiplicative gate on attention + # output. CastedLinear so restore_fp32_params casts back to fp32 for GPTQ. + # _zero_init -> 2*sigmoid(0)=1 -> transparent at init. + self.attn_out_gate = attn_out_gate + self.attn_out_gate_src = attn_out_gate_src + self.gate_window = gate_window + if attn_out_gate: + self.attn_gate_proj = CastedLinear(gate_window, num_heads, bias=False) + self.attn_gate_proj._zero_init = True + # Gated Attention (arXiv:2505.06708, Qwen, NeurIPS 2025). Per-head sigmoid + # gate on SDPA output, BEFORE out_proj. Gate projection W_g: (num_heads, dim). + # Name "attn_gate_w" contains "attn_gate" substring so it matches + # CONTROL_TENSOR_NAME_PATTERNS and routes to the scalar AdamW group. + # fp32 Parameter -> restore_fp32_params path covers it via the ndim<2 OR + # name-pattern check (name matches "attn_gate"). Cast to x.dtype on use. + self.gated_attn = gated_attn + if gated_attn: + W = torch.empty(num_heads, dim, dtype=torch.float32) + nn.init.normal_(W, mean=0.0, std=gated_attn_init_std) + self.attn_gate_w = nn.Parameter(W) + # Sparse attention head-output gate (modded-nanogpt style). Keeps dense SDPA + # and only narrows the gate input to the first gate_window residual dims. + # W_g: (num_heads, gate_window). y_{t,h} <- sigmoid(scale * W_g_h @ x_t[:gate_window]) * y_{t,h}. + # Shares attn_gate_w name with dense GatedAttn so the quant routing + # (CONTROL_TENSOR_NAME_PATTERNS / attn_gate_w int8 passthrough) is unchanged. + self.sparse_attn_gate = sparse_attn_gate + self.sparse_attn_gate_scale = sparse_attn_gate_scale + if sparse_attn_gate: + W = torch.empty(num_heads, gate_window, dtype=torch.float32) + if sparse_attn_gate_init_std > 0: + nn.init.normal_(W, mean=0.0, std=sparse_attn_gate_init_std) + else: + nn.init.zeros_(W) + self.attn_gate_w = nn.Parameter(W) + + def _xsa_efficient(self, y, v): + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x, q_w, k_w, v_w, out_w, cu_seqlens=None, max_seqlen=0): + bsz, seqlen, dim = x.shape + # q_raw kept around as a tap point for attn_out_gate_src='q' (post-projection, + # pre-reshape, pre-RoPE). + q_raw = F.linear(x, q_w.to(x.dtype)) + q = q_raw.reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = F.linear(x, k_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = F.linear(x, v_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + if cu_seqlens is not None: + y = flash_attn_varlen_func( + q[0], + k[0], + v[0], + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + causal=True, + window_size=(-1, -1), + )[None] + else: + y = flash_attn_3_func(q, k, v, causal=True) + if self.use_xsa: + y = self._xsa_efficient(y, v) + # AttnOutGate inlined (PR #1667). Inline + .contiguous() barrier so torch.compile + # fullgraph=True is happy (this avoids the @torch.compiler.disable trap that + # crashed gates v3). Per-head gate on (B,T,H,D) tensor: g shape [B,T,H], broadcast + # over D via [..., None]. zero-init weight -> 2*sigmoid(0)=1 -> transparent. + if self.attn_out_gate: + gate_src = q_raw if self.attn_out_gate_src == "q" else x + gate_in = gate_src[..., : self.gate_window].contiguous() + g = 2.0 * torch.sigmoid(self.attn_gate_proj(gate_in)) + y = y * g[..., None] + # Gated Attention (arXiv:2505.06708 G1). Inline + .contiguous() barrier so + # torch.compile fullgraph=True is happy. Per-head gate on (B,T,H,D): g shape + # [B,T,H], broadcast over D via [..., None]. Paper: g = sigmoid(x @ W_g.T) + # where W_g: (H, dim). .to(x.dtype) on fp32 param before broadcast with bf16. + if self.gated_attn: + x_c = x.contiguous() + g = torch.sigmoid(F.linear(x_c, self.attn_gate_w.to(x.dtype))) + y = y * g[..., None] + # Sparse head-output gate: narrower (gate_window) input, same shape g as GatedAttn. + if self.sparse_attn_gate: + gate_in = x[..., : self.gate_window].contiguous() + g = torch.sigmoid( + self.sparse_attn_gate_scale + * F.linear(gate_in, self.attn_gate_w.to(x.dtype)) + ) + y = y * g[..., None] + y = y.reshape(bsz, seqlen, dim) + self._last_proj_input = y.detach() if getattr(self, "_calib", False) else None + return F.linear(y, out_w.to(x.dtype)) + + +class MLP(nn.Module): + def __init__(self, dim, mlp_mult): + super().__init__() + self.use_fused = True + + def forward(self, x, up_w, down_w): + if self.training and self.use_fused: + return FusedLeakyReLUSquareMLP(x, up_w.to(x.dtype), down_w.to(x.dtype)) + hidden = F.leaky_relu(F.linear(x, up_w.to(x.dtype)), negative_slope=0.3).square() + self._last_down_input = hidden.detach() if getattr(self, "_calib", False) else None + return F.linear(hidden, down_w.to(x.dtype)) + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + train_seq_len, + layer_idx=0, + ln_scale=False, + yarn=True, + attn_out_gate=False, + attn_out_gate_src="proj", + gate_window=12, + gated_attn=False, + gated_attn_init_std=0.01, + sparse_attn_gate=False, + sparse_attn_gate_init_std=0.0, + sparse_attn_gate_scale=1.0, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention( + dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len, yarn=yarn, + attn_out_gate=attn_out_gate, attn_out_gate_src=attn_out_gate_src, gate_window=gate_window, + gated_attn=gated_attn, gated_attn_init_std=gated_attn_init_std, + sparse_attn_gate=sparse_attn_gate, + sparse_attn_gate_init_std=sparse_attn_gate_init_std, + sparse_attn_gate_scale=sparse_attn_gate_scale, + ) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter( + torch.stack((torch.ones(dim), torch.zeros(dim))).float() + ) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + + def forward(self, x, x0, q_w, k_w, v_w, out_w, up_w, down_w, cu_seqlens=None, max_seqlen=0): + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn( + self.attn_norm(x_in) * self.ln_scale_factor, + q_w, k_w, v_w, out_w, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[ + None, None, : + ] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor, up_w, down_w) + return x_out + +class GPT(nn.Module): + def __init__(self, h): + super().__init__() + if h.logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {h.logit_softcap}") + self.tie_embeddings = h.tie_embeddings + self.tied_embed_init_std = h.tied_embed_init_std + self.logit_softcap = h.logit_softcap + self.fused_ce_enabled = bool(h.fused_ce_enabled) + self.tok_emb = nn.Embedding(h.vocab_size, h.model_dim) + self.num_layers = h.num_layers + head_dim = h.model_dim // h.num_heads + kv_dim = h.num_kv_heads * head_dim + hidden_dim = int(h.mlp_mult * h.model_dim) + self.qo_bank = nn.Parameter(torch.empty(2 * h.num_layers, h.model_dim, h.model_dim)) + self.kv_bank = nn.Parameter(torch.empty(2 * h.num_layers, kv_dim, h.model_dim)) + self.mlp_up_bank = nn.Parameter(torch.empty(h.num_layers, hidden_dim, h.model_dim)) + self.mlp_down_bank = nn.Parameter(torch.empty(h.num_layers, h.model_dim, hidden_dim)) + self.num_encoder_layers = h.num_layers // 2 + self.num_decoder_layers = h.num_layers - self.num_encoder_layers + self.blocks = nn.ModuleList( + [ + Block( + h.model_dim, + h.num_heads, + h.num_kv_heads, + h.mlp_mult, + h.rope_base, + h.qk_gain_init, + h.train_seq_len, + layer_idx=i, + ln_scale=h.ln_scale, + yarn=h.rope_yarn, + attn_out_gate=h.attn_out_gate_enabled, + attn_out_gate_src=h.attn_out_gate_src, + gate_window=h.gate_window, + gated_attn=h.gated_attn_enabled, + gated_attn_init_std=h.gated_attn_init_std, + sparse_attn_gate=h.sparse_attn_gate_enabled, + sparse_attn_gate_init_std=h.sparse_attn_gate_init_std, + sparse_attn_gate_scale=h.sparse_attn_gate_scale, + ) + for i in range(h.num_layers) + ] + ) + if h.rope_dims > 0: + head_dim = h.model_dim // h.num_heads + for block in self.blocks: + block.attn.rope_dims = h.rope_dims + block.attn.rotary = Rotary( + head_dim, + base=h.rope_base, + train_seq_len=h.train_seq_len, + rope_dims=h.rope_dims, + yarn=h.rope_yarn, + ) + self.final_norm = RMSNorm() + self.lm_head = ( + None + if h.tie_embeddings + else CastedLinear(h.model_dim, h.vocab_size, bias=False) + ) + if self.lm_head is not None: + self.lm_head._zero_init = True + if h.xsa_last_n > 0: + for i in range(max(0, h.num_layers - h.xsa_last_n), h.num_layers): + self.blocks[i].attn.use_xsa = True + self.looping_active = False + if h.num_loops > 0: + loop_seg = list(range(h.loop_start, h.loop_end + 1)) + all_indices = list(range(h.loop_start)) + for _ in range(h.num_loops + 1): + all_indices.extend(loop_seg) + all_indices.extend(range(h.loop_end + 1, h.num_layers)) + num_enc = len(all_indices) // 2 + self.encoder_indices = all_indices[:num_enc] + self.decoder_indices = all_indices[num_enc:] + else: + self.encoder_indices = list(range(self.num_encoder_layers)) + self.decoder_indices = list(range(self.num_encoder_layers, h.num_layers)) + self.num_skip_weights = min( + len(self.encoder_indices), len(self.decoder_indices) + ) + self.skip_weights = nn.Parameter( + torch.ones(self.num_skip_weights, h.model_dim, dtype=torch.float32) + ) + self.skip_gates = ( + nn.Parameter( + torch.zeros(self.num_skip_weights, h.model_dim, dtype=torch.float32) + ) + if h.skip_gates_enabled + else None + ) + self.parallel_start_layer = h.parallel_start_layer + self.parallel_final_lane = h.parallel_final_lane.lower() + self.parallel_post_lambdas = nn.Parameter( + torch.ones(h.num_layers, 2, 2, dtype=torch.float32) + ) + self.parallel_resid_lambdas = nn.Parameter( + torch.full((h.num_layers, 2), 1.1, dtype=torch.float32) + ) + # SmearGate (PR #1667 / modded-nanogpt @classiclarryd): + # x_t <- x_t + lam * sigmoid(W * x_t[:gate_window]) * x_{t-1}. + # Per-token forward-1 smear of the embedding lane. W zero-init + lam=0 -> + # transparent at init. Uses CastedLinear so restore_fp32_params handles dtype. + self.smear_gate_enabled = h.smear_gate_enabled + if self.smear_gate_enabled: + self.smear_window = h.gate_window + self.smear_gate = CastedLinear(self.smear_window, 1, bias=False) + self.smear_gate._zero_init = True + self.smear_lambda = nn.Parameter(torch.zeros(1, dtype=torch.float32)) + # V19: Asymmetric Logit Rescale (PR #1923 jorge-asenjo). + # Two learnable softcap scales applied on the EVAL path (forward_logits + + # forward_ttt). Init to logit_softcap so the layer is identity at step 0. + # Train path keeps the single fused softcap to preserve PR #1855 numerics. + self.asym_logit_enabled = bool(int(os.environ.get("ASYM_LOGIT_RESCALE", "1"))) + if self.asym_logit_enabled: + self.softcap_pos = nn.Parameter(torch.tensor(float(h.logit_softcap), dtype=torch.float32)) + self.softcap_neg = nn.Parameter(torch.tensor(float(h.logit_softcap), dtype=torch.float32)) + self.jepa_horizon = h.jepa_lite_horizon + self.jepa_stride = h.jepa_lite_stride + self.register_buffer( + "jepa_weight_current", + torch.zeros((), dtype=torch.float32), + persistent=False, + ) + self.jepa_predictor = ( + nn.Sequential( + CastedLinear(h.model_dim, h.jepa_lite_hidden, bias=False), + nn.LeakyReLU(0.5), + CastedLinear(h.jepa_lite_hidden, h.model_dim, bias=False), + ) + if h.jepa_lite_enabled + else None + ) + self.bijepa_enabled = bool(h.bijepa_enabled) + self.bijepa_backward_weight = float(h.bijepa_backward_weight) + self.jepa_bwd_predictor = ( + nn.Sequential( + CastedLinear(h.model_dim, h.jepa_lite_hidden, bias=False), + nn.LeakyReLU(0.5), + CastedLinear(h.jepa_lite_hidden, h.model_dim, bias=False), + ) + if h.jepa_lite_enabled and h.bijepa_enabled + else None + ) + self._init_weights() + + def _init_weights(self): + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + n = self.num_layers + proj_scale = 1.0 / math.sqrt(2 * n) + for i in range(n): + nn.init.orthogonal_(self.qo_bank.data[i], gain=1.0) + nn.init.zeros_(self.qo_bank.data[n + i]) + self.qo_bank.data[n + i].mul_(proj_scale) + nn.init.orthogonal_(self.kv_bank.data[i], gain=1.0) + nn.init.orthogonal_(self.kv_bank.data[n + i], gain=1.0) + for i in range(n): + nn.init.orthogonal_(self.mlp_up_bank.data[i], gain=1.0) + nn.init.zeros_(self.mlp_down_bank.data[i]) + self.mlp_down_bank.data[i].mul_(proj_scale) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif ( + module.weight.ndim == 2 + and module.weight.shape[0] >= 64 + and module.weight.shape[1] >= 64 + ): + nn.init.orthogonal_(module.weight, gain=1.0) + + def _bank_weights(self, i): + n = self.num_layers + return ( + self.qo_bank[i], + self.kv_bank[i], + self.kv_bank[n + i], + self.qo_bank[n + i], + self.mlp_up_bank[i], + self.mlp_down_bank[i], + ) + + def _parallel_block( + self, block_idx, lane0, lane1, x0, + q_w, k_w, v_w, out_w, up_w, down_w, + cu_seqlens=None, max_seqlen=0, + ): + block = self.blocks[block_idx] + mix = block.resid_mix.to(dtype=lane0.dtype) + attn_read = mix[0][None, None, :] * lane0 + mix[1][None, None, :] * x0 + attn_out = block.attn( + block.attn_norm(attn_read) * block.ln_scale_factor, + q_w, k_w, v_w, out_w, + cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, + ) + attn_out = block.attn_scale.to(dtype=attn_out.dtype)[None, None, :] * attn_out + mlp_read = lane1 + mlp_out = block.mlp_scale.to(dtype=lane1.dtype)[None, None, :] * block.mlp( + block.mlp_norm(mlp_read) * block.ln_scale_factor, up_w, down_w + ) + attn_resid = self.parallel_resid_lambdas[block_idx, 0].to(dtype=lane0.dtype) + attn_post = self.parallel_post_lambdas[block_idx, 0].to(dtype=lane0.dtype) + mlp_resid = self.parallel_resid_lambdas[block_idx, 1].to(dtype=lane0.dtype) + mlp_post = self.parallel_post_lambdas[block_idx, 1].to(dtype=lane0.dtype) + lane0 = attn_resid * lane0 + attn_post[0] * attn_out + mlp_post[0] * mlp_out + lane1 = mlp_resid * lane1 + attn_post[1] * attn_out + mlp_post[1] * mlp_out + return lane0, lane1 + + def _final_parallel_hidden(self, lane0, lane1): + if self.parallel_final_lane == "mlp": + return lane1 + if self.parallel_final_lane == "attn": + return lane0 + return 0.5 * (lane0 + lane1) + + def _forward_hidden(self, input_ids, cu_seqlens=None, max_seqlen=0): + """Run the encoder/decoder stack to the final RMSNorm; returns pre-projection hidden. + Shared by eval (softcap+projection via forward_logits) and train (fused CE path).""" + x = self.tok_emb(input_ids) + # SmearGate (PR #1667). lam=0 + W=0 -> identity at init. + # Cross-doc leak fix: zero the prev-token smear at any position whose current token + # is BOS, so the BOS embedding starting doc N+1 in a packed stream is not + # contaminated by doc N's last token (audited issue on PR#1797 base). + if self.smear_gate_enabled: + sl = self.smear_lambda.to(dtype=x.dtype) + gate_in = x[:, 1:, : self.smear_window].contiguous() + g = sl * torch.sigmoid(self.smear_gate(gate_in)) + not_bos = (input_ids[:, 1:] != BOS_ID).to(x.dtype).unsqueeze(-1) + x = torch.cat([x[:, :1], x[:, 1:] + g * x[:, :-1] * not_bos], dim=1) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips = [] + enc_iter = ( + self.encoder_indices + if self.looping_active + else range(self.num_encoder_layers) + ) + dec_iter = ( + self.decoder_indices + if self.looping_active + else range( + self.num_encoder_layers, + self.num_encoder_layers + self.num_decoder_layers, + ) + ) + for i in enc_iter: + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + x = self.blocks[i](x, x0, q_w, k_w, v_w, out_w, up_w, down_w, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + skips.append(x) + psl = self.parallel_start_layer + lane0 = None + lane1 = None + for skip_idx, i in enumerate(dec_iter): + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + if i >= psl and psl > 0: + if lane0 is None: + lane0 = x + lane1 = x + if skip_idx < self.num_skip_weights and skips: + skip = skips.pop() + w = self.skip_weights[skip_idx].to(dtype=lane0.dtype)[None, None, :] + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=lane0.dtype))[None, None, :] + lane0 = torch.lerp(w * skip, lane0, g) + else: + lane0 = lane0 + w * skip + lane0, lane1 = self._parallel_block( + i, lane0, lane1, x0, q_w, k_w, v_w, out_w, up_w, down_w, + cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, + ) + else: + if skip_idx < self.num_skip_weights and skips: + scaled_skip = ( + self.skip_weights[skip_idx].to(dtype=x.dtype)[None, None, :] + * skips.pop() + ) + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=x.dtype))[None, None, :] + x = torch.lerp(scaled_skip, x, g) + else: + x = x + scaled_skip + x = self.blocks[i](x, x0, q_w, k_w, v_w, out_w, up_w, down_w, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + if lane0 is not None: + x = self._final_parallel_hidden(lane0, lane1) + x = self.final_norm(x) + return x + + def _project_logits(self, hidden): + if self.tie_embeddings: + return F.linear(hidden, self.tok_emb.weight) + return self.lm_head(hidden) + + def _apply_asym_softcap(self, logits): + # V19: Asymmetric softcap (PR #1923). Splits the logit_softcap scalar into + # learnable positive/negative branches. Score-first preserved: still a + # bounded, normalized post-projection nonlinearity feeding a standard + # softmax over the full vocab. + sp = self.softcap_pos.to(logits.dtype) + sn = self.softcap_neg.to(logits.dtype) + return torch.where(logits > 0, sp * torch.tanh(logits / sp), sn * torch.tanh(logits / sn)) + + def forward_logits(self, input_ids, cu_seqlens=None, max_seqlen=0): + hidden = self._forward_hidden(input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + logits_proj = self._project_logits(hidden) + if self.asym_logit_enabled: + return self._apply_asym_softcap(logits_proj) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + def jepa_aux_loss(self, features): + if self.jepa_predictor is None: + return features.new_zeros(()) + horizon = max(1, int(self.jepa_horizon)) + stride = max(1, int(self.jepa_stride)) + if features.size(1) <= horizon: + return features.new_zeros(()) + src = features[:, :-horizon:stride, :] + tgt = features[:, horizon::stride, :] + n = min(src.size(1), tgt.size(1)) + if n <= 0: + return features.new_zeros(()) + src = F.normalize(src[:, :n, :].float(), dim=-1) + tgt = F.normalize(tgt[:, :n, :].detach().float(), dim=-1) + pred = F.normalize(self.jepa_predictor(src), dim=-1) + loss = (1.0 - (pred * tgt).sum(dim=-1)).mean() + if self.jepa_bwd_predictor is not None: + bwd_src = F.normalize(features[:, horizon::stride, :][:, :n, :].float(), dim=-1) + bwd_tgt = F.normalize( + features[:, :-horizon:stride, :][:, :n, :].detach().float(), + dim=-1, + ) + bwd_pred = F.normalize(self.jepa_bwd_predictor(bwd_src), dim=-1) + bwd_loss = (1.0 - (bwd_pred * bwd_tgt).sum(dim=-1)).mean() + loss = loss + self.bijepa_backward_weight * bwd_loss + return loss * self.jepa_weight_current.to( + device=features.device, + dtype=features.dtype, + ) + + def forward(self, input_ids, target_ids, cu_seqlens=None, max_seqlen=0): + hidden = self._forward_hidden(input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + logits_proj = self._project_logits(hidden) + flat_targets = target_ids.reshape(-1) + # Fused softcapped-CE kernel (training path only). Applies softcap inside the + # Triton kernel; takes pre-softcap logits_proj. Non-fused path matches stock + # PR-1736 numerics exactly (softcap in fp32, then F.cross_entropy on fp32). + if self.fused_ce_enabled: + ce_loss = softcapped_cross_entropy( + logits_proj.reshape(-1, logits_proj.size(-1)), + flat_targets, + self.logit_softcap, + reduction="mean", + ) + else: + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + ce_loss = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + flat_targets, + reduction="mean", + ) + if self.jepa_predictor is not None: + ce_loss = ce_loss + self.jepa_aux_loss(hidden) + return ce_loss + + def forward_with_hidden(self, input_ids, target_ids, cu_seqlens=None, max_seqlen=0): + hidden = self._forward_hidden( + input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ) + logits_proj = self._project_logits(hidden) + flat_targets = target_ids.reshape(-1) + if self.fused_ce_enabled: + ce_loss = softcapped_cross_entropy( + logits_proj.reshape(-1, logits_proj.size(-1)), + flat_targets, + self.logit_softcap, + reduction="mean", + ) + else: + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + ce_loss = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + flat_targets, + reduction="mean", + ) + if self.jepa_predictor is not None: + ce_loss = ce_loss + self.jepa_aux_loss(hidden) + return ce_loss, hidden + + def forward_ttt(self, input_ids, target_ids, lora, hint_ids=None): + x = self.tok_emb(input_ids) + # SmearGate on the TTT path — same inline compute as forward_logits. + # Cross-doc leak fix: see _forward_hidden comment. + if self.smear_gate_enabled: + sl = self.smear_lambda.to(dtype=x.dtype) + gate_in = x[:, 1:, : self.smear_window].contiguous() + g = sl * torch.sigmoid(self.smear_gate(gate_in)) + not_bos = (input_ids[:, 1:] != BOS_ID).to(x.dtype).unsqueeze(-1) + x = torch.cat([x[:, :1], x[:, 1:] + g * x[:, :-1] * not_bos], dim=1) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips = [] + enc_iter = ( + self.encoder_indices + if self.looping_active + else list(range(self.num_encoder_layers)) + ) + dec_iter = ( + self.decoder_indices + if self.looping_active + else list( + range( + self.num_encoder_layers, + self.num_encoder_layers + self.num_decoder_layers, + ) + ) + ) + slot = 0 + for i in enc_iter: + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + x = self._block_with_lora(self.blocks[i], x, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w) + slot += 1 + skips.append(x) + psl = self.parallel_start_layer + lane0 = None + lane1 = None + for skip_idx, i in enumerate(dec_iter): + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + if i >= psl and psl > 0: + if lane0 is None: + lane0 = x + lane1 = x + if skip_idx < self.num_skip_weights and skips: + skip = skips.pop() + w = self.skip_weights[skip_idx].to(dtype=lane0.dtype)[None, None, :] + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=lane0.dtype))[None, None, :] + lane0 = torch.lerp(w * skip, lane0, g) + else: + lane0 = lane0 + w * skip + lane0, lane1 = self._parallel_block_with_lora( + i, lane0, lane1, x0, lora, slot, + q_w, k_w, v_w, out_w, up_w, down_w, + ) + else: + if skip_idx < self.num_skip_weights and skips: + scaled_skip = ( + self.skip_weights[skip_idx].to(dtype=x.dtype)[None, None, :] + * skips.pop() + ) + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=x.dtype))[None, None, :] + x = torch.lerp(scaled_skip, x, g) + else: + x = x + scaled_skip + x = self._block_with_lora(self.blocks[i], x, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w) + slot += 1 + if lane0 is not None: + x = self._final_parallel_hidden(lane0, lane1) + x = self.final_norm(x) + if self.tie_embeddings: + logits = F.linear(x, self.tok_emb.weight) + else: + logits = self.lm_head(x) + logits = logits + lora.lm_head_lora(x) + # V19: same asymmetric softcap on the TTT eval path. + if self.asym_logit_enabled: + logits = self._apply_asym_softcap(logits) + else: + logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) + bsz, sl, V = logits.shape + if hint_ids is None: + return F.cross_entropy( + logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none" + ).reshape(bsz, sl) + if not logits.requires_grad: + log_p_y, log_q_h = fused_log_softmax_dual_gather( + logits, target_ids, hint_ids.clamp(min=0) + ) + return -log_p_y, log_q_h + ls = F.log_softmax(logits.float(), dim=-1) + log_p_y = ls.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1) + log_q_h = ls.gather(-1, hint_ids.clamp(min=0).unsqueeze(-1)).squeeze(-1) + return -log_p_y, log_q_h + + def _block_with_lora(self, block, x, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w): + mix = block.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + n = block.attn_norm(x_in) * block.ln_scale_factor + attn = block.attn + bsz, seqlen, dim = n.shape + # Keep raw Q for AttnOutGate src='q' (matches forward path semantics). + q_raw = F.linear(n, q_w.to(n.dtype)) + if lora.q_loras is not None: + q_raw = q_raw + lora.q_loras[slot](n) + q = q_raw.reshape(bsz, seqlen, attn.num_heads, attn.head_dim) + k = F.linear(n, k_w.to(n.dtype)) + if lora.k_loras is not None: + k = k + lora.k_loras[slot](n) + k = k.reshape(bsz, seqlen, attn.num_kv_heads, attn.head_dim) + v = F.linear(n, v_w.to(n.dtype)) + if lora.v_loras is not None: + v = v + lora.v_loras[slot](n) + v = v.reshape(bsz, seqlen, attn.num_kv_heads, attn.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = attn.rotary(seqlen, n.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, attn.rope_dims) + k = apply_rotary_emb(k, cos, sin, attn.rope_dims) + q = q * attn.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if attn.use_xsa: + y = attn._xsa_efficient(y, v) + # AttnOutGate (TTT path) — inline + .contiguous() barrier, same as the eval path. + if attn.attn_out_gate: + gate_src = q_raw if attn.attn_out_gate_src == "q" else n + gate_in = gate_src[..., : attn.gate_window].contiguous() + g = 2.0 * torch.sigmoid(attn.attn_gate_proj(gate_in)) + y = y * g[..., None] + # Gated Attention (TTT path). Gate input is n (post-norm block input), same + # as eval path. .to(n.dtype) on fp32 param before bf16 broadcast. + if attn.gated_attn: + n_c = n.contiguous() + g = torch.sigmoid(F.linear(n_c, attn.attn_gate_w.to(n.dtype))) + y = y * g[..., None] + # Sparse attention head-output gate (TTT path) — must match the eval path in + # forward() exactly, else training (which applied the gate) and TTT eval (which + # skipped it) produce mismatched representations and catastrophic BPB regression. + if attn.sparse_attn_gate: + gate_in = n[..., : attn.gate_window].contiguous() + g = torch.sigmoid( + attn.sparse_attn_gate_scale + * F.linear(gate_in, attn.attn_gate_w.to(n.dtype)) + ) + y = y * g[..., None] + y = y.reshape(bsz, seqlen, dim) + attn_out = F.linear(y, out_w.to(n.dtype)) + if lora.o_loras is not None: + attn_out = attn_out + lora.o_loras[slot](n) + x_out = x_in + block.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + mlp_n = block.mlp_norm(x_out) * block.ln_scale_factor + mlp_out = block.mlp(mlp_n, up_w, down_w) + if lora.mlp_loras is not None: + mlp_out = mlp_out + lora.mlp_loras[slot](mlp_n) + x_out = x_out + block.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * mlp_out + return x_out + + def _parallel_block_with_lora( + self, block_idx, lane0, lane1, x0, lora, slot, + q_w, k_w, v_w, out_w, up_w, down_w, + ): + block = self.blocks[block_idx] + mix = block.resid_mix.to(dtype=lane0.dtype) + attn_read = mix[0][None, None, :] * lane0 + mix[1][None, None, :] * x0 + n = block.attn_norm(attn_read) * block.ln_scale_factor + attn = block.attn + bsz, seqlen, dim = n.shape + q_raw = F.linear(n, q_w.to(n.dtype)) + if lora.q_loras is not None: + q_raw = q_raw + lora.q_loras[slot](n) + q = q_raw.reshape(bsz, seqlen, attn.num_heads, attn.head_dim) + k = F.linear(n, k_w.to(n.dtype)) + if lora.k_loras is not None: + k = k + lora.k_loras[slot](n) + k = k.reshape(bsz, seqlen, attn.num_kv_heads, attn.head_dim) + v = F.linear(n, v_w.to(n.dtype)) + if lora.v_loras is not None: + v = v + lora.v_loras[slot](n) + v = v.reshape(bsz, seqlen, attn.num_kv_heads, attn.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = attn.rotary(seqlen, n.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, attn.rope_dims) + k = apply_rotary_emb(k, cos, sin, attn.rope_dims) + q = q * attn.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if attn.use_xsa: + y = attn._xsa_efficient(y, v) + # AttnOutGate (TTT parallel path) — inline + .contiguous() barrier. + if attn.attn_out_gate: + gate_src = q_raw if attn.attn_out_gate_src == "q" else n + gate_in = gate_src[..., : attn.gate_window].contiguous() + g = 2.0 * torch.sigmoid(attn.attn_gate_proj(gate_in)) + y = y * g[..., None] + # Gated Attention (TTT parallel path). Gate input is n (post-norm block input). + if attn.gated_attn: + n_c = n.contiguous() + g = torch.sigmoid(F.linear(n_c, attn.attn_gate_w.to(n.dtype))) + y = y * g[..., None] + # Sparse attention head-output gate (TTT parallel path) — must match the + # eval path in forward() to keep train/eval semantics in sync. + if attn.sparse_attn_gate: + gate_in = n[..., : attn.gate_window].contiguous() + g = torch.sigmoid( + attn.sparse_attn_gate_scale + * F.linear(gate_in, attn.attn_gate_w.to(n.dtype)) + ) + y = y * g[..., None] + y = y.reshape(bsz, seqlen, dim) + attn_out = F.linear(y, out_w.to(n.dtype)) + if lora.o_loras is not None: + attn_out = attn_out + lora.o_loras[slot](n) + attn_out = block.attn_scale.to(dtype=attn_out.dtype)[None, None, :] * attn_out + mlp_read = lane1 + mlp_n = block.mlp_norm(mlp_read) * block.ln_scale_factor + mlp_out = block.mlp(mlp_n, up_w, down_w) + if lora.mlp_loras is not None: + mlp_out = mlp_out + lora.mlp_loras[slot](mlp_n) + mlp_out = block.mlp_scale.to(dtype=lane1.dtype)[None, None, :] * mlp_out + attn_resid = self.parallel_resid_lambdas[block_idx, 0].to(dtype=lane0.dtype) + attn_post = self.parallel_post_lambdas[block_idx, 0].to(dtype=lane0.dtype) + mlp_resid = self.parallel_resid_lambdas[block_idx, 1].to(dtype=lane0.dtype) + mlp_post = self.parallel_post_lambdas[block_idx, 1].to(dtype=lane0.dtype) + lane0 = attn_resid * lane0 + attn_post[0] * attn_out + mlp_post[0] * mlp_out + lane1 = mlp_resid * lane1 + attn_post[1] * attn_out + mlp_post[1] * mlp_out + return lane0, lane1 + + +class BatchedLinearLoRA(nn.Module): + # PR-1767: rank-scaled output (alpha/rank), like standard LoRA. Decouples + # effective magnitude from rank so changing rank does not change LR scale. + _ALPHA = float(os.environ.get("TTT_LORA_ALPHA", "144")) + # PR-1767: optionally keep A warm across per-doc resets (only B is zeroed). + # Accumulates useful feature directions across documents within a TTT phase. + _WARM_START_A = bool(int(os.environ.get("TTT_WARM_START_A", "1"))) + + def __init__(self, bsz, in_features, out_features, rank): + super().__init__() + self._bound = 1.0 / math.sqrt(in_features) + self._scale = self._ALPHA / rank + self.A = nn.Parameter( + torch.empty(bsz, rank, in_features).uniform_(-self._bound, self._bound) + ) + self.B = nn.Parameter(torch.zeros(bsz, out_features, rank)) + + def reset(self): + with torch.no_grad(): + if not self._WARM_START_A: + self.A.uniform_(-self._bound, self._bound) + self.B.zero_() + + def forward(self, x): + return ((x @ self.A.transpose(1, 2)) @ self.B.transpose(1, 2)) * self._scale + + +class BatchedTTTLoRA(nn.Module): + def __init__( + self, bsz, model, rank, + q_lora=True, k_lora=True, v_lora=True, mlp_lora=True, o_lora=True, + ): + super().__init__() + self.bsz = bsz + dim = model.qo_bank.shape[-1] + vocab = model.tok_emb.num_embeddings + if getattr(model, "looping_active", False): + num_slots = len(model.encoder_indices) + len(model.decoder_indices) + else: + num_slots = len(model.blocks) + kv_dim = model.blocks[0].attn.num_kv_heads * ( + dim // model.blocks[0].attn.num_heads + ) + embed_dim = model.tok_emb.embedding_dim + self.lm_head_lora = BatchedLinearLoRA(bsz, embed_dim, vocab, rank) + self.q_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_slots)] + ) + if q_lora + else None + ) + self.v_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, kv_dim, rank) for _ in range(num_slots)] + ) + if v_lora + else None + ) + self.k_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, kv_dim, rank) for _ in range(num_slots)] + ) + if k_lora + else None + ) + self.mlp_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_slots)] + ) + if mlp_lora + else None + ) + self.o_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_slots)] + ) + if o_lora + else None + ) + + def reset(self): + with torch.no_grad(): + self.lm_head_lora.reset() + for loras in [self.q_loras, self.v_loras, self.k_loras, + self.mlp_loras, self.o_loras]: + if loras is not None: + for lora in loras: + lora.reset() + + +# Polar Express per-iteration minimax Newton-Schulz coefficients (PR #1344). +# Replaces the fixed (3.4445, -4.775, 2.0315) coefficients of stock Muon. +# Applied at backend_steps=5 — taking more than 5 iterations from this list +# falls back to the final (converged) tuple via the slice guard below. +_PE_COEFFS = ( + (8.156554524902461, -22.48329292557795, 15.878769915207462), + (4.042929935166739, -2.808917465908714, 0.5000178451051316), + (3.8916678022926607, -2.772484153217685, 0.5060648178503393), + (3.285753657755655, -2.3681294933425376, 0.46449024233003106), + (2.3465413258596377, -1.7097828382687081, 0.42323551169305323), +) + + +@torch.compile +def zeropower_via_newtonschulz5(G, steps=10, eps=1e-07): + was_2d = G.ndim == 2 + if was_2d: + G = G.unsqueeze(0) + X = G.bfloat16() + transposed = X.size(-2) > X.size(-1) + if transposed: + X = X.mT + X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps) + coeffs = _PE_COEFFS[:steps] if steps <= len(_PE_COEFFS) else _PE_COEFFS + for a, b, c in coeffs: + A = X @ X.mT + B = b * A + c * (A @ A) + X = a * X + B @ X + if transposed: + X = X.mT + if was_2d: + X = X.squeeze(0) + return X + + +class Muon(torch.optim.Optimizer): + def __init__( + self, + params, + lr, + momentum, + backend_steps, + nesterov=True, + weight_decay=0.0, + row_normalize=False, + ): + super().__init__( + params, + dict( + lr=lr, + momentum=momentum, + backend_steps=backend_steps, + nesterov=nesterov, + weight_decay=weight_decay, + row_normalize=row_normalize, + ), + ) + self._built = False + + def _build(self): + self._distributed = dist.is_available() and dist.is_initialized() + self._world_size = dist.get_world_size() if self._distributed else 1 + self._rank = dist.get_rank() if self._distributed else 0 + ws = self._world_size + self._bank_meta = [] + for group in self.param_groups: + for p in group["params"]: + B = p.shape[0] + padded_B = ((B + ws - 1) // ws) * ws + shard_B = padded_B // ws + tail = p.shape[1:] + dev = p.device + self._bank_meta.append({ + "p": p, + "B": B, + "padded_grad": torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + "shard": torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + "shard_mom": torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + "full_update": torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + "scale": max(1, p.shape[-2] / p.shape[-1]) ** 0.5, + }) + self._bank_meta.sort(key=lambda m: -m["p"].numel()) + self._built = True + + def launch_reduce_scatters(self): + if not self._built: + self._build() + if not self._distributed: + return + self._rs_futures = [] + for m in self._bank_meta: + p = m["p"] + if p.grad is None: + self._rs_futures.append(None) + continue + pg = m["padded_grad"] + pg[: m["B"]].copy_(p.grad) + fut = dist.reduce_scatter_tensor( + m["shard"], pg, op=dist.ReduceOp.AVG, async_op=True + ) + self._rs_futures.append(fut) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + if not self._built: + self._build() + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + row_normalize = group.get("row_normalize", False) + prev_ag_handle = None + prev_m = None + sharded = self._distributed and hasattr(self, "_rs_futures") + for idx, m in enumerate(self._bank_meta): + p = m["p"] + if p.grad is None: + continue + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m["p"] + upd = prev_m["full_update"][: prev_m["B"]] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd, alpha=-lr * prev_m["scale"]) + if sharded and self._rs_futures[idx] is not None: + self._rs_futures[idx].wait() + g = m["shard"] + buf = m["shard_mom"] + else: + g = p.grad.bfloat16() + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + update = g.add(buf, alpha=momentum) + else: + update = buf + if row_normalize: + rn = update.float().norm(dim=-1, keepdim=True).clamp_min(1e-07) + update = update / rn.to(update.dtype) + update = zeropower_via_newtonschulz5(update, steps=backend_steps) + if sharded: + prev_ag_handle = dist.all_gather_into_tensor( + m["full_update"], update, async_op=True + ) + prev_m = m + else: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + p.add_(update, alpha=-lr * m["scale"]) + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m["p"] + upd = prev_m["full_update"][: prev_m["B"]] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd, alpha=-lr * prev_m["scale"]) + if hasattr(self, "_rs_futures"): + del self._rs_futures + return loss + + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,skip_gates,parallel_post_lambdas,parallel_resid_lambdas,attn_gate_proj,attn_gate_w,smear_gate,smear_lambda", + ).split(",") + if pattern +) + + +PACKED_REPLICATED_GRAD_MAX_NUMEL = 1 << 15 + + +class Optimizers: + def __init__(self, h, base_model): + matrix_params = [ + base_model.qo_bank, + base_model.kv_bank, + base_model.mlp_up_bank, + base_model.mlp_down_bank, + ] + block_named_params = list(base_model.blocks.named_parameters()) + scalar_params = [ + p + for (name, p) in block_named_params + if p.ndim < 2 + or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if base_model.skip_gates is not None and base_model.skip_gates.numel() > 0: + scalar_params.append(base_model.skip_gates) + if base_model.parallel_post_lambdas is not None: + scalar_params.append(base_model.parallel_post_lambdas) + if base_model.parallel_resid_lambdas is not None: + scalar_params.append(base_model.parallel_resid_lambdas) + # SmearGate params live on GPT root (not in .blocks), so add them by hand. + # Both are tiny (gate_window scalars + 1 lambda). Optimized via scalar Adam. + if getattr(base_model, "smear_gate_enabled", False): + scalar_params.append(base_model.smear_gate.weight) + scalar_params.append(base_model.smear_lambda) + token_lr = h.tied_embed_lr if h.tie_embeddings else h.embed_lr + tok_params = [ + {"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr} + ] + self.optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=h.embed_wd, + fused=True, + ) + self.optimizer_muon = Muon( + matrix_params, + lr=h.matrix_lr, + momentum=h.muon_momentum, + backend_steps=h.muon_backend_steps, + weight_decay=h.muon_wd, + row_normalize=h.muon_row_normalize, + ) + for group in self.optimizer_muon.param_groups: + group["base_lr"] = h.matrix_lr + self.optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": h.scalar_lr, "base_lr": h.scalar_lr}], + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=h.adam_wd, + fused=True, + ) + jepa_params = ( + list(base_model.jepa_predictor.parameters()) + if getattr(base_model, "jepa_predictor", None) is not None + else [] + ) + if getattr(base_model, "jepa_bwd_predictor", None) is not None: + jepa_params.extend(base_model.jepa_bwd_predictor.parameters()) + if jepa_params: + self.optimizer_jepa = torch.optim.AdamW( + [ + { + "params": jepa_params, + "lr": h.jepa_lite_lr, + "base_lr": h.jepa_lite_lr, + } + ], + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=0.0, + fused=True, + ) + else: + self.optimizer_jepa = None + self.optimizers = [ + self.optimizer_tok, + self.optimizer_muon, + self.optimizer_scalar, + ] + if self.optimizer_jepa is not None: + self.optimizers.append(self.optimizer_jepa) + self.replicated_params = list(tok_params[0]["params"]) + self.replicated_params.extend(scalar_params) + self.replicated_params.extend(jepa_params) + self.replicated_large_params = [] + self.replicated_packed_params = [] + for p in self.replicated_params: + if p.numel() <= PACKED_REPLICATED_GRAD_MAX_NUMEL: + self.replicated_packed_params.append(p) + else: + self.replicated_large_params.append(p) + self._aux_stream = torch.cuda.Stream() + + def __iter__(self): + return iter(self.optimizers) + + def zero_grad_all(self): + for opt in self.optimizers: + opt.zero_grad(set_to_none=True) + + def _all_reduce_packed_grads(self): + grads_by_key = collections.defaultdict(list) + for p in self.replicated_packed_params: + if p.grad is not None: + grads_by_key[(p.grad.device, p.grad.dtype)].append(p.grad) + for grads in grads_by_key.values(): + flat = torch.empty( + sum(g.numel() for g in grads), + device=grads[0].device, + dtype=grads[0].dtype, + ) + offset = 0 + for g in grads: + n = g.numel() + flat[offset : offset + n].copy_(g.contiguous().view(-1)) + offset += n + dist.all_reduce(flat, op=dist.ReduceOp.AVG) + offset = 0 + for g in grads: + n = g.numel() + g.copy_(flat[offset : offset + n].view_as(g)) + offset += n + + def step(self, distributed=False): + self.optimizer_muon.launch_reduce_scatters() + if distributed: + reduce_handles = [ + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG, async_op=True) + for p in self.replicated_large_params + if p.grad is not None + ] + self._all_reduce_packed_grads() + for handle in reduce_handles: + handle.wait() + self._aux_stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self._aux_stream): + self.optimizer_tok.step() + self.optimizer_scalar.step() + if self.optimizer_jepa is not None: + self.optimizer_jepa.step() + self.optimizer_muon.step() + torch.cuda.current_stream().wait_stream(self._aux_stream) + self.zero_grad_all() + + +def restore_fp32_params(model): + for module in model.modules(): + if isinstance(module, CastedLinear): + module.float() + for name, param in model.named_parameters(): + if ( + param.ndim < 2 + or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ) and param.dtype != torch.float32: + param.data = param.data.float() + if hasattr(model, "qo_bank") and model.qo_bank is not None: + model.qo_bank.data = model.qo_bank.data.float() + model.kv_bank.data = model.kv_bank.data.float() + model.mlp_up_bank.data = model.mlp_up_bank.data.float() + model.mlp_down_bank.data = model.mlp_down_bank.data.float() + + +def collect_hessians(model, train_loader, h, device, n_calibration_batches=64): + hessians = {} + act_sumsq = {} + act_counts = {} + hooks = [] + for i, block in enumerate(model.blocks): + block.attn._calib = True + block.mlp._calib = True + block.mlp.use_fused = False + + def make_attn_hook(layer_idx): + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + x_sq = x.square().sum(dim=0) + x_count = x.shape[0] + for suffix in ["c_q", "c_k", "c_v"]: + name = f"blocks.{layer_idx}.attn.{suffix}.weight" + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + if name not in act_sumsq: + act_sumsq[name] = torch.zeros( + x.shape[1], dtype=torch.float32, device=device + ) + act_counts[name] = 0 + act_sumsq[name] += x_sq + act_counts[name] += x_count + y = module._last_proj_input + if y is not None: + y = y.float() + if y.ndim == 3: + y = y.reshape(-1, y.shape[-1]) + name = f"blocks.{layer_idx}.attn.proj.weight" + if name not in hessians: + hessians[name] = torch.zeros( + y.shape[1], y.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(y.T, y) + if name not in act_sumsq: + act_sumsq[name] = torch.zeros( + y.shape[1], dtype=torch.float32, device=device + ) + act_counts[name] = 0 + act_sumsq[name] += y.square().sum(dim=0) + act_counts[name] += y.shape[0] + return hook_fn + + def make_mlp_hook(layer_idx): + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + name = f"blocks.{layer_idx}.mlp.fc.weight" + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + if name not in act_sumsq: + act_sumsq[name] = torch.zeros( + x.shape[1], dtype=torch.float32, device=device + ) + act_counts[name] = 0 + act_sumsq[name] += x.square().sum(dim=0) + act_counts[name] += x.shape[0] + h_act = module._last_down_input + if h_act is not None: + h_act = h_act.float() + if h_act.ndim == 3: + h_act = h_act.reshape(-1, h_act.shape[-1]) + name = f"blocks.{layer_idx}.mlp.proj.weight" + if name not in hessians: + hessians[name] = torch.zeros( + h_act.shape[1], h_act.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(h_act.T, h_act) + if name not in act_sumsq: + act_sumsq[name] = torch.zeros( + h_act.shape[1], dtype=torch.float32, device=device + ) + act_counts[name] = 0 + act_sumsq[name] += h_act.square().sum(dim=0) + act_counts[name] += h_act.shape[0] + return hook_fn + + for i, block in enumerate(model.blocks): + hooks.append(block.attn.register_forward_hook(make_attn_hook(i))) + hooks.append(block.mlp.register_forward_hook(make_mlp_hook(i))) + + # Hessian hooks for embedding factorization projection layers + def make_linear_input_hook(weight_name): + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + if weight_name not in hessians: + hessians[weight_name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[weight_name].addmm_(x.T, x) + return hook_fn + + if model.tie_embeddings: + hook_module = model.final_norm + + def make_output_hook(name): + def hook_fn(module, inp, out): + x = out.detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + if name not in act_sumsq: + act_sumsq[name] = torch.zeros( + x.shape[1], dtype=torch.float32, device=device + ) + act_counts[name] = 0 + act_sumsq[name] += x.square().sum(dim=0) + act_counts[name] += x.shape[0] + return hook_fn + + hooks.append( + hook_module.register_forward_hook(make_output_hook("tok_emb.weight")) + ) + model.eval() + with torch.no_grad(): + for _ in range(n_calibration_batches): + x, _ = train_loader.next_batch(h.train_batch_tokens, h.grad_accum_steps) + model.forward_logits(x) + for hook in hooks: + hook.remove() + for i, block in enumerate(model.blocks): + block.attn._calib = False + block.mlp._calib = False + block.mlp.use_fused = True + for name in hessians: + hessians[name] = hessians[name].cpu() / n_calibration_batches + act_stats = {} + for name, sumsq in act_sumsq.items(): + count = max(act_counts.get(name, 0), 1) + act_stats[name] = (sumsq / count).sqrt().cpu() + return hessians, act_stats + + +def gptq_quantize_weight( + w, + H, + clip_sigmas=3.0, + clip_range=63, + block_size=128, + protect_groups=None, + group_size=None, + protect_clip_range=None, +): + W_orig = w.float().clone() + rows, cols = W_orig.shape + H = H.float().clone() + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + damp = 0.01 * H.diag().mean() + H.diagonal().add_(damp) + perm = torch.argsort(H.diag(), descending=True) + invperm = torch.argsort(perm) + W_perm = W_orig[:, perm].clone() + W_perm[:, dead[perm]] = 0 + H = H[perm][:, perm] + H_flip = torch.flip(H, dims=(0, 1)) + L_flip = torch.linalg.cholesky(H_flip) + U = torch.flip(L_flip, dims=(0, 1)) + eye = torch.eye(H.shape[0], device=H.device, dtype=H.dtype) + Hinv = torch.linalg.solve_triangular(U, eye, upper=True) + row_std = W_orig.std(dim=1) + s = (clip_sigmas * row_std / clip_range).clamp_min(1e-10).to(torch.float16) + sf = s.float() + protect_meta = None + protect_mask_perm = None + s_hi = None + sf_hi = None + if ( + protect_groups + and group_size is not None + and protect_clip_range is not None + and protect_clip_range > clip_range + ): + protect_mask = torch.zeros(cols, dtype=torch.bool) + starts = [] + for (start, end) in protect_groups: + if start < 0 or end > cols or end <= start: + continue + protect_mask[start:end] = True + starts.append(start) + if starts: + protect_mask_perm = protect_mask[perm] + s_hi = (clip_sigmas * row_std / protect_clip_range).clamp_min(1e-10).to( + torch.float16 + ) + sf_hi = s_hi.float() + protect_meta = { + "starts": torch.tensor(starts, dtype=torch.int16), + "size": int(group_size), + "s_hi": s_hi, + } + Q = torch.zeros(rows, cols, dtype=torch.int8) + W_work = W_perm.clone() + for i1 in range(0, cols, block_size): + i2 = min(i1 + block_size, cols) + W_block = W_work[:, i1:i2].clone() + Hinv_block = Hinv[i1:i2, i1:i2] + Err = torch.zeros(rows, i2 - i1) + for j in range(i2 - i1): + w_col = W_block[:, j] + d = Hinv_block[j, j] + if protect_mask_perm is not None and bool(protect_mask_perm[i1 + j]): + q_col = torch.clamp( + torch.round(w_col / sf_hi), + -protect_clip_range, + protect_clip_range, + ) + w_recon = q_col.float() * sf_hi + else: + q_col = torch.clamp(torch.round(w_col / sf), -clip_range, clip_range) + w_recon = q_col.float() * sf + Q[:, i1 + j] = q_col.to(torch.int8) + err = (w_col - w_recon) / d + Err[:, j] = err + W_block[:, j:] -= err.unsqueeze(1) * Hinv_block[j, j:].unsqueeze(0) + if i2 < cols: + W_work[:, i2:] -= Err @ Hinv[i1:i2, i2:] + return Q[:, invperm], s, protect_meta + + +def _quantize_gate_int8_row(w): + # Symmetric int8-per-row quantization for small gate tensors. w shape + # (R, C) -> (R,) scales in fp16, int8 values in [-127, 127]. Single scale + # per row keeps accuracy high while halving storage vs fp16. + W = w.float().contiguous() + row_max = W.abs().amax(dim=1).clamp_min(1e-10) + s = (row_max / 127.0).to(torch.float16) + sf = s.float().view(-1, 1) + q = torch.clamp(torch.round(W / sf), -127, 127).to(torch.int8) + return q, s + + +def _lqer_pack(A, B, bits): + rng = 2 ** (bits - 1) - 1 + sA = (A.abs().amax(dim=1).clamp_min(1e-10) / rng).to(torch.float16) + sB = (B.abs().amax(dim=1).clamp_min(1e-10) / rng).to(torch.float16) + qA = torch.clamp(torch.round(A / sA.float().view(-1, 1)), -rng, rng).to(torch.int8) + qB = torch.clamp(torch.round(B / sB.float().view(-1, 1)), -rng, rng).to(torch.int8) + return qA, sA, qB, sB + + +def _lqer_pack_asym(A, B, g=64): + # A: INT2 per-matrix scalar (signed [-2,1], scale = |A|max/1.5). + sA = (A.abs().amax().clamp_min(1e-10) / 1.5).to(torch.float16) + qA = torch.clamp(torch.round(A / sA.float()), -2, 1).to(torch.int8) + # B: INT4 groupwise g over flattened B (signed [-8,7], per-group scale). + Bf = B.reshape(-1, g) + Bmax = Bf.abs().amax(dim=-1, keepdim=True).clamp_min(1e-10) + sB = (Bmax / 7.5).to(torch.float16).reshape(-1) + qB = torch.clamp(torch.round(Bf / sB.float().reshape(-1, 1)), -8, 7).to( + torch.int8 + ).reshape(B.shape) + return qA, sA, qB, sB + + +def _lqer_fit_quantized(E, h): + U, S, Vh = torch.linalg.svd(E, full_matrices=False) + r = min(h.lqer_rank, S.numel()) + if r <= 0: + return None + A = (U[:, :r] * S[:r]).contiguous() + B = Vh[:r, :].contiguous() + asym_on = bool(getattr(h, "lqer_asym_enabled", False)) + asym_g = int(getattr(h, "lqer_asym_group", 64)) + if asym_on and B.numel() % asym_g == 0: + qA, sA, qB, sB = _lqer_pack_asym(A, B, asym_g) + A_hat = qA.float() * float(sA) + g_sz = qB.numel() // sB.numel() + B_hat = (qB.reshape(-1, g_sz).float() * sB.float().view(-1, 1)).reshape( + qB.shape + ) + return { + "kind": "asym", + "qA": qA, + "sA": sA, + "qB": qB, + "sB": sB, + "delta": A_hat @ B_hat, + } + qA, sA, qB, sB = _lqer_pack(A, B, h.lqer_factor_bits) + A_hat = qA.float() * sA.float().view(-1, 1) + B_hat = qB.float() * sB.float().view(-1, 1) + return { + "kind": "sym", + "qA": qA, + "sA": sA, + "qB": qB, + "sB": sB, + "delta": A_hat @ B_hat, + } + + +def _awq_lite_group_candidates(w, act_rms, group_size): + cols = w.shape[1] + n_groups = cols // group_size + if n_groups <= 0: + return [] + weight_score = w.float().abs().mean(dim=0) + saliency = act_rms.float() * weight_score + cands = [] + for gi in range(n_groups): + start = gi * group_size + end = start + group_size + score = float(saliency[start:end].sum()) + cands.append((score, start, end)) + return cands + + +def gptq_mixed_quantize(state_dict, hessians, act_stats, h): + result = {} + meta = {} + quant_gate = bool(getattr(h, "gated_attn_quant_gate", False)) + lqer_on = bool(getattr(h, "lqer_enabled", False)) + awq_on = bool(getattr(h, "awq_lite_enabled", False)) + lqer_cands = {} + awq_selected = collections.defaultdict(list) + if awq_on: + awq_cands = [] + for (name, tensor) in state_dict.items(): + t = tensor.detach().cpu().contiguous() + if t.is_floating_point() and t.numel() > 65536 and name in act_stats: + bits = h.embed_bits if "tok_emb" in name else h.matrix_bits + if bits < h.awq_lite_bits: + for score, start, end in _awq_lite_group_candidates( + t, act_stats[name], h.awq_lite_group_size + ): + awq_cands.append((score, name, start, end)) + awq_cands.sort(key=lambda x: -x[0]) + for (_score, name, start, end) in awq_cands[: h.awq_lite_group_top_k]: + awq_selected[name].append((start, end)) + for (name, tensor) in state_dict.items(): + t = tensor.detach().cpu().contiguous() + # Dedicated int8-per-row path for attn_gate_w (bypasses both GPTQ and + # fp16 passthrough). Applied BEFORE the numel<=65536 passthrough check + # so the gate tensor is routed here instead of to fp16. + if ( + quant_gate + and t.is_floating_point() + and t.ndim == 2 + and name.endswith(".attn_gate_w") + # Dense GatedAttn: (num_heads, dim) = (8, 512) = 4096. + # Sparse gate: (num_heads, gate_window) = (8, 12) = 96. + # Both need int8-per-row routing; the 1024 lower bound in stock + # PR-1736 presumed dense-only. Widen to catch both. + and 32 <= t.numel() <= 8192 + ): + gq, gs = _quantize_gate_int8_row(t) + result[name + ".gq"] = gq + result[name + ".gs"] = gs + meta[name] = "gate_int8_row" + continue + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough (float16)" + continue + if "tok_emb" in name: + cs = h.embed_clip_sigmas + elif ".mlp." in name: + cs = h.mlp_clip_sigmas + elif ".attn." in name: + cs = h.attn_clip_sigmas + else: + cs = h.matrix_clip_sigmas + bits = h.embed_bits if "tok_emb" in name else h.matrix_bits + clip_range = 2 ** (bits - 1) - 1 + q, s, protect_meta = gptq_quantize_weight( + t, + hessians[name], + clip_sigmas=cs, + clip_range=clip_range, + protect_groups=awq_selected.get(name), + group_size=h.awq_lite_group_size if name in awq_selected else None, + protect_clip_range=(2 ** (h.awq_lite_bits - 1) - 1) + if name in awq_selected + else None, + ) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = f"gptq (int{bits})" + W_q = q.float() * s.float().view(-1, 1) + if protect_meta is not None: + result[name + ".awqg_start"] = protect_meta["starts"] + result[name + ".awqg_s_hi"] = protect_meta["s_hi"] + result[name + ".awqg_size"] = torch.tensor( + protect_meta["size"], dtype=torch.int16 + ) + meta[name] = meta[name] + f"+awqgrpint{h.awq_lite_bits}" + gsz = protect_meta["size"] + for start in protect_meta["starts"].tolist(): + W_q[:, start : start + gsz] = ( + q[:, start : start + gsz].float() + * protect_meta["s_hi"].float().view(-1, 1) + ) + if lqer_on: + # LQER is fit on top of the fully realized GPTQ base, which already + # includes any higher-precision AWQ-protected groups. + scope = str(getattr(h, "lqer_scope", "all")).lower() + scope_ok = ( + scope == "all" + or (scope == "mlp" and ".mlp." in name) + or (scope == "attn" and ".attn." in name) + or (scope == "embed" and "tok_emb" in name) + ) + if scope_ok: + E = t.float() - W_q + err_norm = float(E.norm()) + if err_norm > 0: + lqer_cands[name] = (E, err_norm) + if lqer_on and lqer_cands: + if bool(getattr(h, "lqer_gain_select", False)): + scored = [] + for (name, (E, base_err)) in lqer_cands.items(): + fit = _lqer_fit_quantized(E, h) + if fit is None: + continue + new_err = float((E - fit["delta"]).norm()) + gain = base_err - new_err + if gain > 0: + scored.append((gain, name, fit)) + scored.sort(key=lambda x: -x[0]) + for (_gain, name, fit) in scored[: h.lqer_top_k]: + if fit["kind"] == "asym": + result[name + ".lqA_a"] = fit["qA"] + result[name + ".lqAs_a"] = fit["sA"] + result[name + ".lqB_a"] = fit["qB"] + result[name + ".lqBs_a"] = fit["sB"] + meta[name] = meta[name] + "+lqer_asym" + else: + result[name + ".lqA"] = fit["qA"] + result[name + ".lqAs"] = fit["sA"] + result[name + ".lqB"] = fit["qB"] + result[name + ".lqBs"] = fit["sB"] + meta[name] = meta[name] + "+lqer" + else: + top = sorted(lqer_cands.items(), key=lambda kv: -kv[1][1])[: h.lqer_top_k] + asym_on = bool(getattr(h, "lqer_asym_enabled", False)) + asym_g = int(getattr(h, "lqer_asym_group", 64)) + for (name, (E, _)) in top: + U, S, Vh = torch.linalg.svd(E, full_matrices=False) + r = min(h.lqer_rank, S.numel()) + A = (U[:, :r] * S[:r]).contiguous() + B = Vh[:r, :].contiguous() + if asym_on and B.numel() % asym_g == 0: + qA, sA, qB, sB = _lqer_pack_asym(A, B, asym_g) + result[name + ".lqA_a"] = qA + result[name + ".lqAs_a"] = sA + result[name + ".lqB_a"] = qB + result[name + ".lqBs_a"] = sB + meta[name] = meta[name] + "+lqer_asym" + else: + qA, sA, qB, sB = _lqer_pack(A, B, h.lqer_factor_bits) + result[name + ".lqA"] = qA + result[name + ".lqAs"] = sA + result[name + ".lqB"] = qB + result[name + ".lqBs"] = sB + meta[name] = meta[name] + "+lqer" + categories = collections.defaultdict(set) + for (name, cat) in meta.items(): + short = re.sub("\\.\\d+$", "", re.sub("blocks\\.\\d+", "blocks", name)) + categories[cat].add(short) + log("Quantized weights:") + for cat in sorted(categories): + log(f" {cat}: {', '.join(sorted(categories[cat]))}") + return result, meta + +def dequantize_mixed(result, meta, template_sd): + out = {} + for (name, orig) in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if "passthrough" in info: + t = result[name] + if t.dtype == torch.float16 and orig_dtype in ( + torch.float32, + torch.bfloat16, + ): + t = t.to(orig_dtype) + out[name] = t + continue + if info == "gate_int8_row": + gq = result[name + ".gq"] + gs = result[name + ".gs"] + out[name] = (gq.float() * gs.float().view(-1, 1)).to(orig_dtype) + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + W = q.float() * s.float().view(q.shape[0], *[1] * (q.ndim - 1)) + else: + W = q.float() * float(s.item()) + if "awqgrpint" in info: + starts = result[name + ".awqg_start"].tolist() + s_hi = result[name + ".awqg_s_hi"].float() + gsz = int(result[name + ".awqg_size"].item()) + for start in starts: + W[:, start : start + gsz] = ( + q[:, start : start + gsz].float() * s_hi.view(-1, 1) + ) + if "lqer_asym" in info: + qA_t = result[name + ".lqA_a"] + sA_t = result[name + ".lqAs_a"] + qB_t = result[name + ".lqB_a"] + sB_t = result[name + ".lqBs_a"] + qA = qA_t.float() * float(sA_t) + g_sz = qB_t.numel() // sB_t.numel() + qB = (qB_t.reshape(-1, g_sz).float() * sB_t.float().view(-1, 1)).reshape( + qB_t.shape + ) + W = W + qA @ qB + elif "lqer" in info: + qA = result[name + ".lqA"].float() * result[name + ".lqAs"].float().view(-1, 1) + qB = result[name + ".lqB"].float() * result[name + ".lqBs"].float().view(-1, 1) + W = W + qA @ qB + out[name] = W.to(orig_dtype) + return out + + +_BSHF_MAGIC = b"BSHF" + + +# ── Per-group lrzip compression (ported from PR#1586 via PR#1667/1729) ──────── + +_GROUP_ORDER = [ + "_tok_emb.weight.q", + "attn.c_k.weight.q", "attn.c_q.weight.q", + "attn.c_v.weight.q", "attn.proj.weight.q", + "mlp.fc.weight.q", "mlp.proj.weight.q", +] +_SIMSORT_KEYS = {"_tok_emb.weight.q", "attn.c_q.weight.q", "mlp.fc.weight.q"} +_PACK_MAGIC = b"PGRP" + + +def _similarity_sort_l1(matrix): + import numpy as _np + n = matrix.shape[0] + used = _np.zeros(n, dtype=bool) + order = [0] + used[0] = True + cur = matrix[0].astype(_np.float32) + for _ in range(n - 1): + dists = _np.sum(_np.abs(matrix[~used].astype(_np.float32) - cur), axis=1) + unused = _np.where(~used)[0] + best = unused[_np.argmin(dists)] + order.append(best) + used[best] = True + cur = matrix[best].astype(_np.float32) + return _np.array(order, dtype=_np.uint16) + + +def _lrzip_compress(data, tmpdir, label): + inp = os.path.join(tmpdir, f"{label}.bin") + out = f"{inp}.lrz" + with open(inp, "wb") as f: + f.write(data) + subprocess.run(["lrzip", "-z", "-L", "9", "-o", out, inp], capture_output=True, check=True) + with open(out, "rb") as f: + result = f.read() + os.remove(inp); os.remove(out) + return result + + +def _lrzip_decompress(data, tmpdir, label): + inp = os.path.join(tmpdir, f"{label}.lrz") + out = os.path.join(tmpdir, f"{label}.bin") + with open(inp, "wb") as f: + f.write(data) + subprocess.run(["lrzip", "-d", "-f", "-o", out, inp], capture_output=True, check=True) + with open(out, "rb") as f: + result = f.read() + os.remove(inp); os.remove(out) + return result + + +def _pack_streams(streams): + import struct + n = len(streams) + hdr = _PACK_MAGIC + struct.pack("", p) + if m: + return bytes([int(m.group(1), 16)]) + return (" " + p[1:]).encode() if p.startswith("▁") else p.encode() + + +def _ppm_mixture_bpb(tgt_np, lp_np, sp, O=4, H=0.9, L_=0.05, T=0.9, token_byte_lens_np=None): + V = sp.vocab_size() + piece_bytes = [None] * V + piece_lens = np.zeros(V, dtype=np.int32) + for i in range(V): + b = _ppm_piece_bytes(sp, i) + piece_bytes[i] = b + piece_lens[i] = len(b) + if token_byte_lens_np is None: + per_tok_len = piece_lens[tgt_np] + bs = b''.join(piece_bytes[int(t)] for t in tgt_np) + kept_lp = lp_np + else: + chunks = [] + kept_lp_parts = [] + lens_parts = [] + for t, lp, side_len in zip(tgt_np, lp_np, token_byte_lens_np): + side_len = int(side_len) + if side_len <= 0: + continue + b = piece_bytes[int(t)] + if not b: + continue + if len(b) > side_len: + b = b[:side_len] + elif len(b) < side_len: + b = b + b[-1:] * (side_len - len(b)) + chunks.append(b) + kept_lp_parts.append(float(lp)) + lens_parts.append(side_len) + if not chunks: + return float("inf") + bs = b''.join(chunks) + per_tok_len = np.asarray(lens_parts, dtype=np.int32) + kept_lp = np.asarray(kept_lp_parts, dtype=np.float64) + N = len(bs) + rep_lp = np.repeat(kept_lp.astype(np.float64), per_tok_len) + rep_len = np.repeat(per_tok_len.astype(np.float64), per_tok_len) + nlp = np.where(rep_len > 0, rep_lp / rep_len, 0.0) + tabs = [dict() for _ in range(O + 1)] + plp = np.empty(N, dtype=np.float64) + cf = np.empty(N, dtype=np.float64) + LN256 = math.log(1 / 256) + log_ = math.log + h_ctx = b'' + for i in range(N): + x = bs[i] + if i == 0: + plp[i] = LN256 + cf[i] = 1 / 256 + else: + esc = 1.0 + pf = 0.0 + cf_mx = 0 + cf_tot = 256 + cf_seen = False + lim = O if i > O else i + for o in range(lim, -1, -1): + k = h_ctx[-o:] if o else b'' + e = tabs[o].get(k) + if e is None: + continue + if not cf_seen: + cf_mx = e[1] + cf_tot = e[0] + cf_seen = True + tot = e[0] + d = e[2] + c = d.get(x, 0) + if c > 0: + pf = esc * (2 * c - 1) / (2 * tot) + break + esc *= len(d) / (2 * tot) + else: + pf = esc / 256 + if pf < 1e-20: + pf = 1e-20 + plp[i] = log_(pf) + cf[i] = (cf_mx / cf_tot) if cf_seen else 1 / 256 + for o in range(O + 1): + k = h_ctx[-o:] if o else b'' + e = tabs[o].get(k) + if e is None: + tabs[o][k] = [1, 1, {x: 1}] + else: + e[0] += 1 + d = e[2] + cnt = d.get(x, 0) + 1 + d[x] = cnt + if cnt > e[1]: + e[1] = cnt + h_ctx = (h_ctx + bytes([x]))[-O:] + nn_prob = np.exp(nlp) + ppm_prob = np.exp(plp) + + def _mix_bpb(Hv, Lv, Tv): + lam_v = np.where(cf > Tv, Lv, Hv) + pm_v = lam_v * nn_prob + (1 - lam_v) * ppm_prob + return float(-np.log2(np.maximum(pm_v, 1e-300)).sum() / N) + + default_bpb = _mix_bpb(H, L_, T) + if os.environ.get("PPM_SWEEP_GRID", "0") == "1": + hs = [float(x) for x in os.environ.get("PPM_SWEEP_HS", str(H)).split(",") if x.strip()] + ls = [float(x) for x in os.environ.get("PPM_SWEEP_LS", str(L_)).split(",") if x.strip()] + ts = [float(x) for x in os.environ.get("PPM_SWEEP_TS", str(T)).split(",") if x.strip()] + combo_count = len(hs) * len(ls) * len(ts) + max_combos = int(os.environ.get("PPM_SWEEP_MAX_COMBOS", "256")) + if combo_count > max_combos and os.environ.get("PPM_SWEEP_ALLOW_SLOW", "0") != "1": + log( + f"ppm_sweep skipped: combos={combo_count} max={max_combos}; " + "dump inputs and replay offline, or set PPM_SWEEP_ALLOW_SLOW=1" + ) + return default_bpb + best = (default_bpb, H, L_, T) + for Hv in hs: + for Lv in ls: + for Tv in ts: + bpb = _mix_bpb(Hv, Lv, Tv) + if bpb < best[0]: + best = (bpb, Hv, Lv, Tv) + log( + f"ppm_sweep best_bpb:{best[0]:.8f} H={best[1]} L={best[2]} T={best[3]} " + f"default_bpb:{default_bpb:.8f}" + ) + if os.environ.get("PPM_SWEEP_APPLY", "0") == "1": + return best[0] + return default_bpb + + +def eval_val_ppm_sliding(h, device, val_data, model, batch_seqs=32): + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + model.eval() + seq_len = h.eval_seq_len + stride = h.eval_stride + context_size = seq_len - stride + total_tokens = val_data.val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) if ws + context_size < total_tokens] + total_windows = len(window_starts) + my_s = total_windows * h.rank // h.world_size + my_e = total_windows * (h.rank + 1) // h.world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + tga_local = [] + lpa_local = [] + bla_local = [] + fwd_fn = model.module.forward_logits if hasattr(model, 'module') else model.forward_logits + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens = [] + for i, ws in enumerate(batch_ws): + we = min(ws + seq_len, total_tokens) + wlen = we - ws + wlens.append(wlen) + chunk = val_data.val_tokens[ws:we + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type='cuda', dtype=torch.bfloat16): + logits = fwd_fn(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction='none', + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else context_size + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + if val_data.val_bytes is not None: + tb = val_data.val_bytes[ws + s + 1: ws + wlen + 1].to(device=device, dtype=torch.float64) + else: + prev = x_batch[i, s:wlen] + tb = val_data.base_bytes_lut[tgt].to(torch.float64) + tb += (val_data.has_leading_space_lut[tgt] & ~val_data.is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + tga_local.append(tgt.cpu().to(torch.int64)) + lpa_local.append((-scored_nll).cpu().to(torch.float64)) + bla_local.append(tb.cpu().to(torch.int32)) + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss, val_bpb = _loss_bpb(loss_sum, token_count, byte_count) + if h.ppm_mixer_enabled: + tga_local_cat = torch.cat(tga_local) if tga_local else torch.zeros(0, dtype=torch.int64) + lpa_local_cat = torch.cat(lpa_local) if lpa_local else torch.zeros(0, dtype=torch.float64) + bla_local_cat = torch.cat(bla_local) if bla_local else torch.zeros(0, dtype=torch.int32) + if dist.is_available() and dist.is_initialized(): + local_size = torch.tensor([tga_local_cat.numel()], dtype=torch.int64, device=device) + sizes = [torch.zeros(1, dtype=torch.int64, device=device) for _ in range(h.world_size)] + dist.all_gather(sizes, local_size) + sizes_list = [int(s.item()) for s in sizes] + max_size = max(sizes_list) if sizes_list else 0 + tga_pad = torch.zeros(max_size, dtype=torch.int64, device=device) + lpa_pad = torch.zeros(max_size, dtype=torch.float64, device=device) + bla_pad = torch.zeros(max_size, dtype=torch.int32, device=device) + tga_pad[:tga_local_cat.numel()] = tga_local_cat.to(device) + lpa_pad[:lpa_local_cat.numel()] = lpa_local_cat.to(device) + bla_pad[:bla_local_cat.numel()] = bla_local_cat.to(device) + if h.rank == 0: + gather_t = [torch.zeros(max_size, dtype=torch.int64, device=device) for _ in range(h.world_size)] + gather_l = [torch.zeros(max_size, dtype=torch.float64, device=device) for _ in range(h.world_size)] + gather_b = [torch.zeros(max_size, dtype=torch.int32, device=device) for _ in range(h.world_size)] + else: + gather_t = None + gather_l = None + gather_b = None + dist.gather(tga_pad, gather_t, dst=0) + dist.gather(lpa_pad, gather_l, dst=0) + dist.gather(bla_pad, gather_b, dst=0) + if h.rank == 0: + tga_full = torch.cat([gather_t[r][:sizes_list[r]] for r in range(h.world_size)]).cpu().numpy() + lpa_full = torch.cat([gather_l[r][:sizes_list[r]] for r in range(h.world_size)]).cpu().numpy() + bla_full = torch.cat([gather_b[r][:sizes_list[r]] for r in range(h.world_size)]).cpu().numpy() + if getattr(h, "ppm_dump_inputs", False): + dump_path = os.path.join(h.artifact_dir or ".", f"{h.run_id}.ppm_inputs.npz") + np.savez_compressed( + dump_path, + target_ids=tga_full.astype(np.int64), + logp=lpa_full.astype(np.float64), + byte_lens=bla_full.astype(np.int32), + ) + log(f"ppm_dump_inputs:{dump_path}") + t0 = time.perf_counter() + mixer_bpb = _ppm_mixture_bpb(tga_full, lpa_full, val_data.sp, O=h.ppm_order, H=h.ppm_h, L_=h.ppm_l, T=h.ppm_t, token_byte_lens_np=bla_full) + log(f'ppm_mixer val_bpb:{mixer_bpb:.8f} eval_time:{1000.0*(time.perf_counter()-t0):.0f}ms order={h.ppm_order} H={h.ppm_h} L={h.ppm_l} T={h.ppm_t} N_tokens={lpa_full.size} N_sidecar_bytes={int(bla_full.sum())}') + val_bpb = mixer_bpb + else: + tga_np = tga_local_cat.numpy() + lpa_np = lpa_local_cat.numpy() + bla_np = bla_local_cat.numpy() + if getattr(h, "ppm_dump_inputs", False): + dump_path = os.path.join(h.artifact_dir or ".", f"{h.run_id}.ppm_inputs.npz") + np.savez_compressed( + dump_path, + target_ids=tga_np.astype(np.int64), + logp=lpa_np.astype(np.float64), + byte_lens=bla_np.astype(np.int32), + ) + log(f"ppm_dump_inputs:{dump_path}") + t0 = time.perf_counter() + mixer_bpb = _ppm_mixture_bpb(tga_np, lpa_np, val_data.sp, O=h.ppm_order, H=h.ppm_h, L_=h.ppm_l, T=h.ppm_t, token_byte_lens_np=bla_np) + log(f'ppm_mixer val_bpb:{mixer_bpb:.8f} eval_time:{1000.0*(time.perf_counter()-t0):.0f}ms order={h.ppm_order} H={h.ppm_h} L={h.ppm_l} T={h.ppm_t} N_tokens={lpa_np.size} N_sidecar_bytes={int(bla_np.sum())}') + val_bpb = mixer_bpb + model.train() + return val_loss, val_bpb + + +def eval_val(h, device, val_data, model, forward_logits_fn=None): + seq_len = h.eval_seq_len + local_batch_tokens = h.val_batch_tokens // (h.world_size * h.grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + f"VAL_BATCH_SIZE must provide at least one sequence per rank; got VAL_BATCH_SIZE={h.val_batch_tokens}, WORLD_SIZE={h.world_size}, GRAD_ACCUM_STEPS={h.grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_data.val_tokens.numel() - 1) // seq_len + seq_start = total_seqs * h.rank // h.world_size + seq_end = total_seqs * (h.rank + 1) // h.world_size + + # TODO: Don't truncate this. + seq_end = seq_start + ((seq_end - seq_start) // local_batch_seqs) * local_batch_seqs + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + run_forward_logits = ( + (model.module.forward_logits if hasattr(model, "module") else model.forward_logits) + if forward_logits_fn is None + else forward_logits_fn + ) + model.eval() + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + with torch.no_grad(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_data.val_tokens[raw_start:raw_end].to( + device=device, dtype=torch.int64, non_blocking=True + ) + x = local[:-1] + y = local[1:] + bos_pos = (x == BOS_ID).nonzero(as_tuple=True)[0].tolist() + cu_seqlens, max_seqlen = _build_cu_seqlens( + bos_pos, x.numel(), x.device, h.eval_seq_len, 64 + ) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits = run_forward_logits( + x[None], cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ).detach() + per_token_loss = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y.reshape(-1), + reduction="none", + ) + val_loss_sum += per_token_loss.to(torch.float64).sum() + val_token_count += float(y.numel()) + prev_ids = x + tgt_ids = y + sidecar_slice = val_data.val_bytes[raw_start + 1 : raw_end].to( + device=device, dtype=torch.int32, non_blocking=True + ) + val_byte_count += sidecar_slice.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + model.train() + return _loss_bpb(val_loss_sum, val_token_count, val_byte_count) + + +def jepa_weight_at(h, step, frac): + if not h.jepa_lite_enabled: + return 0.0 + freq = max(1, h.jepa_lite_frequency) + if step % freq != 0 or frac < h.jepa_lite_start_frac: + return 0.0 + weight = h.jepa_lite_weight + if frac > h.jepa_lite_end_frac: + weight *= max( + 0.0, + (1.0 - frac) / max(1e-09, 1.0 - h.jepa_lite_end_frac), + ) + return weight + + +def _find_docs(all_tokens): + bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].numpy() + docs = [] + for i in range(len(bos_positions)): + start = int(bos_positions[i]) + end = ( + int(bos_positions[i + 1]) + if i + 1 < len(bos_positions) + else all_tokens.numel() + ) + if i + 1 < len(bos_positions): + end += 1 + assert end - start >= 2 + docs.append((start, end - start)) + return docs + + +def _build_ttt_global_batches(doc_entries, h, ascending=False): + batch_size = h.ttt_batch_size + global_doc_entries = sorted(doc_entries, key=lambda x: x[1][1]) + global_batches = [ + global_doc_entries[i : i + batch_size] + for i in range(0, len(global_doc_entries), batch_size) + ] + indexed = list(enumerate(global_batches)) + if not ascending: + indexed.sort(key=lambda ib: -max(dl for _, (_, dl) in ib[1])) + return indexed + + +def _init_batch_counter(path): + with open(path, "wb") as f: + f.write((0).to_bytes(4, "little")) + + +def _claim_next_batch(counter_path, queue_len): + try: + with open(counter_path, "r+b") as f: + fcntl.flock(f, fcntl.LOCK_EX) + idx = int.from_bytes(f.read(4), "little") + f.seek(0) + f.write((idx + 1).to_bytes(4, "little")) + f.flush() + except FileNotFoundError: + return queue_len + return idx + + +def _compute_chunk_window(ci, pred_len, num_chunks, chunk_size, eval_seq_len): + chunk_end = pred_len if ci == num_chunks - 1 else (ci + 1) * chunk_size + win_start = max(0, chunk_end - eval_seq_len) + win_len = chunk_end - win_start + chunk_start = ci * chunk_size + chunk_offset = chunk_start - win_start + chunk_len = chunk_end - chunk_start + return win_start, win_len, chunk_offset, chunk_len + + +def _accumulate_bpb( + ptl, + x, + y, + chunk_offsets, + chunk_lens, + pos_idx, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + loss_sum, + byte_sum, + token_count, + y_bytes=None, +): + pos = pos_idx[: x.size(1)].unsqueeze(0) + mask = ( + (chunk_lens.unsqueeze(1) > 0) + & (pos >= chunk_offsets.unsqueeze(1)) + & (pos < (chunk_offsets + chunk_lens).unsqueeze(1)) + ) + mask_f64 = mask.to(torch.float64) + if y_bytes is not None: + tok_bytes = y_bytes.to(torch.float64) + else: + tok_bytes = base_bytes_lut[y].to(torch.float64) + tok_bytes += (has_leading_space_lut[y] & ~is_boundary_token_lut[x]).to( + torch.float64 + ) + loss_sum += (ptl.to(torch.float64) * mask_f64).sum() + byte_sum += (tok_bytes * mask_f64).sum() + token_count += chunk_lens.to(torch.float64).sum() + + +def _loss_bpb_from_sums(loss_sum, token_count, byte_sum): + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_sum.item()) + return val_loss, val_bpb + + +def _add_to_counter(path, delta): + try: + with open(path, "r+b") as f: + fcntl.flock(f, fcntl.LOCK_EX) + cur = int.from_bytes(f.read(8), "little", signed=True) + cur += int(delta) + f.seek(0) + f.write(int(cur).to_bytes(8, "little", signed=True)) + f.flush() + return cur + except FileNotFoundError: + return int(delta) + + +def _init_int64_counter(path): + with open(path, "wb") as f: + f.write((0).to_bytes(8, "little", signed=True)) + + +def _select_ttt_doc_entries(docs, h): + doc_entries = list(enumerate(docs)) + if h.val_doc_fraction < 1.0: + sample_n = max(1, int(round(len(docs) * h.val_doc_fraction))) + if os.environ.get("VAL_DOC_PREFIX_ONLY", "0") == "1": + return doc_entries[:sample_n] + sampled_indices = sorted( + random.Random(h.seed).sample(range(len(docs)), sample_n) + ) + return [(i, docs[i]) for i in sampled_indices] + return doc_entries + + +def train_val_ttt_global_sgd_distributed(h, device, val_data, base_model, val_tokens, batch_seqs=None): + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + base_model.eval() + seq_len = h.eval_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = h.global_ttt_chunk_tokens + batch_seqs = h.global_ttt_batch_seqs if batch_seqs is None else batch_seqs + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + ttt_params = [p for p in base_model.parameters()] + for p in ttt_params: + p.requires_grad_(True) + optimizer = torch.optim.SGD( + ttt_params, lr=h.global_ttt_lr, momentum=h.global_ttt_momentum + ) + t_start = time.perf_counter() + for ci in range(num_chunks): + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + is_last_chunk = ci == num_chunks - 1 + if is_last_chunk or h.global_ttt_epochs <= 0: + continue + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs <= 0: + continue + warmup_chunks = max(0, min(h.global_ttt_warmup_chunks, num_chunks - 1)) + if warmup_chunks > 0 and ci < warmup_chunks: + warmup_denom = max(warmup_chunks - 1, 1) + warmup_t = ci / warmup_denom + lr_now = ( + h.global_ttt_warmup_start_lr + + (h.global_ttt_lr - h.global_ttt_warmup_start_lr) * warmup_t + ) + else: + decay_steps = max(num_chunks - 1 - warmup_chunks, 1) + decay_ci = max(ci - warmup_chunks, 0) + lr_now = h.global_ttt_lr * 0.5 * ( + 1.0 + math.cos(math.pi * decay_ci / decay_steps) + ) + for pg in optimizer.param_groups: + pg["lr"] = lr_now + my_seq_s = chunk_seqs * h.rank // h.world_size + my_seq_e = chunk_seqs * (h.rank + 1) // h.world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ in range(h.global_ttt_epochs): + for bs in range(0, my_chunk_seqs, batch_seqs): + be = min(bs + batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x_flat = local[:-1] + y_flat = local[1:] + optimizer.zero_grad(set_to_none=True) + with torch.enable_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + if h.global_ttt_respect_doc_boundaries: + bos_pos = (x_flat == BOS_ID).nonzero(as_tuple=True)[0].tolist() + cu_seqlens, max_seqlen = _build_cu_seqlens( + bos_pos, x_flat.numel(), x_flat.device, h.eval_seq_len, 64 + ) + loss = base_model( + x_flat[None], + y_flat[None], + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + else: + x = x_flat.reshape(-1, seq_len) + y = y_flat.reshape(-1, seq_len) + loss = base_model(x, y) + loss.backward() + if dist.is_available() and dist.is_initialized(): + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.SUM) + p.grad.mul_(1.0 / h.world_size) + if h.global_ttt_grad_clip > 0: + torch.nn.utils.clip_grad_norm_(ttt_params, h.global_ttt_grad_clip) + optimizer.step() + base_model.eval() + if h.rank == 0: + elapsed = time.perf_counter() - t_start + log( + f"tttg: c{ci+1}/{num_chunks} lr:{lr_now:.6f} t:{elapsed:.1f}s" + ) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + +def _compute_ngram_hints_for_val(h, val_data, log0=print): + if not getattr(h, "ngram_tilt_enabled", False): + return None + from online_ngram_tilt import build_hints_for_targets + + all_tokens = val_data.val_tokens + targets_np_all = all_tokens.cpu().numpy().astype("uint16", copy=False)[1:] + max_targets = int(os.environ.get("NGRAM_HINT_MAX_TARGETS", "0")) + target_count = targets_np_all.shape[0] + if max_targets > 0: + targets_np = targets_np_all[: min(max_targets, target_count)] + else: + targets_np = targets_np_all + t_h0 = time.perf_counter() + hints_pkg = build_hints_for_targets( + target_token_ids_np=targets_np, + tokenizer_path=h.tokenizer_path, + vocab_size=h.vocab_size, + log0=log0, + token_order=h.token_order, + token_threshold=h.token_threshold, + token_boost=h.token_boost, + within_tau=h.within_tau, + within_boost=h.within_boost, + word_order=h.word_order, + word_normalize=h.word_normalize, + word_tau=h.word_tau, + word_boost=h.word_boost, + agree_add_boost=h.agree_add_boost, + ) + hint_global = torch.from_numpy(hints_pkg["hint_ids"].astype("int64")) + gate_global = torch.from_numpy(hints_pkg["gate_mask"]) + boost_global = torch.from_numpy(hints_pkg["boost"].astype("float32")) + if hint_global.numel() < target_count: + padded_hint = torch.zeros(target_count, dtype=torch.int64) + padded_gate = torch.zeros(target_count, dtype=torch.bool) + padded_boost = torch.zeros(target_count, dtype=torch.float32) + padded_hint[: hint_global.numel()] = hint_global + padded_gate[: gate_global.numel()] = gate_global + padded_boost[: boost_global.numel()] = boost_global + hint_global, gate_global, boost_global = padded_hint, padded_gate, padded_boost + log0( + f"ngram_tilt:precompute_done elapsed={time.perf_counter()-t_h0:.2f}s " + f"total_targets={hint_global.numel()} computed_targets={targets_np.shape[0]}" + ) + return hint_global, gate_global, boost_global + + +def eval_val_ttt_phased(h, base_model, device, val_data, forward_ttt_train, precomputed_hints=None): + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + base_model.eval() + for p in base_model.parameters(): + p.requires_grad_(False) + all_tokens = val_data.val_tokens + all_tokens_idx = all_tokens.to(torch.int32) + ngram_hint_global = None + ngram_gate_global = None + ngram_boost_global = None + if precomputed_hints is not None: + ngram_hint_global, ngram_gate_global, ngram_boost_global = precomputed_hints + log( + "ngram_tilt:using_precomputed_hints " + f"total_targets={ngram_hint_global.numel()}" + ) + elif getattr(h, "ngram_tilt_enabled", False): + ngram_hint_global, ngram_gate_global, ngram_boost_global = _compute_ngram_hints_for_val( + h, val_data, log0=log + ) + docs = _find_docs(all_tokens) + doc_entries = _select_ttt_doc_entries(docs, h) + prefix_doc_limit = max(0, min(len(doc_entries), int(h.phased_ttt_prefix_docs))) + num_phases = max(1, int(h.phased_ttt_num_phases)) + phase_boundaries = [] + for pi in range(num_phases): + boundary = prefix_doc_limit * (pi + 1) // num_phases + phase_boundaries.append(boundary) + current_phase = 0 + current_phase_boundary = phase_boundaries[0] + log( + "ttt_phased:" + f" total_docs:{len(doc_entries)} prefix_docs:{prefix_doc_limit} " + f"suffix_docs:{len(doc_entries) - prefix_doc_limit}" + f" num_phases:{num_phases} boundaries:{phase_boundaries}" + ) + chunk_size, eval_seq_len = h.ttt_chunk_size, h.ttt_eval_seq_len + eval_batch_set = None + if h.ttt_eval_batches: + eval_batch_set = set(int(x) for x in h.ttt_eval_batches.split(",") if x.strip()) + use_ascending = eval_batch_set is not None + global_batches_sorted = _build_ttt_global_batches( + doc_entries, h, ascending=use_ascending + ) + queue_len = len(global_batches_sorted) + counter_path = f"/tmp/ttt_counter_{h.run_id}" + prefix_counter_path = f"/tmp/ttt_prefix_counter_{h.run_id}" + pause_flag_path = f"/tmp/ttt_pause_flag_{h.run_id}" + if h.rank == 0: + _init_batch_counter(counter_path) + _init_int64_counter(prefix_counter_path) + try: + os.remove(pause_flag_path) + except FileNotFoundError: + pass + if dist.is_available() and dist.is_initialized(): + path_list = [counter_path, prefix_counter_path, pause_flag_path] + dist.broadcast_object_list(path_list, src=0) + counter_path, prefix_counter_path, pause_flag_path = path_list + dist.barrier() + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + byte_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + t_start = time.perf_counter() + reusable_lora = BatchedTTTLoRA( + h.ttt_batch_size, base_model, h.ttt_lora_rank, + q_lora=h.ttt_q_lora, k_lora=h.ttt_k_lora, v_lora=h.ttt_v_lora, + mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + + def _build_opt(lora): + local_lr = h.ttt_lora_lr * h.ttt_local_lr_mult + if h.ttt_optimizer == "sgd": + return torch.optim.SGD( + lora.parameters(), lr=local_lr, + momentum=h.ttt_beta1, weight_decay=h.ttt_weight_decay, + ) + return torch.optim.AdamW( + lora.parameters(), lr=local_lr, + betas=(h.ttt_beta1, h.ttt_beta2), + eps=1e-10, weight_decay=h.ttt_weight_decay, fused=True, + ) + + reusable_opt = _build_opt(reusable_lora) + local_scored_docs = [] + global_ttt_done = prefix_doc_limit == 0 + try: + while True: + queue_idx = _claim_next_batch(counter_path, queue_len) + if queue_idx >= queue_len: + break + orig_batch_idx, batch_entries = global_batches_sorted[queue_idx] + batch = [doc for _, doc in batch_entries] + bsz = len(batch) + prev_loss = loss_sum.item() + prev_bytes = byte_sum.item() + prev_tokens = token_count.item() + if bsz == reusable_lora.bsz: + reusable_lora.reset() + for s in reusable_opt.state.values(): + for k, v in s.items(): + if isinstance(v, torch.Tensor): + v.zero_() + elif k == "step": + s[k] = 0 + cur_lora = reusable_lora + cur_opt = reusable_opt + else: + cur_lora = BatchedTTTLoRA( + bsz, base_model, h.ttt_lora_rank, + q_lora=h.ttt_q_lora, k_lora=h.ttt_k_lora, v_lora=h.ttt_v_lora, + mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + cur_opt = _build_opt(cur_lora) + pred_lens = [doc_len - 1 for _, doc_len in batch] + num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens] + max_nc = max(num_chunks) + num_chunks_t = torch.tensor(num_chunks, dtype=torch.int64, device=device) + for ci in range(max_nc): + active = [ci < nc for nc in num_chunks] + needs_train = any(ci < nc - 1 for nc in num_chunks) + tok_starts = torch.zeros(bsz, dtype=torch.int64) + tok_wls = torch.zeros(bsz, dtype=torch.int64) + chunk_offsets_cpu = torch.zeros(bsz, dtype=torch.int64) + chunk_lens_cpu = torch.zeros(bsz, dtype=torch.int64) + for b in range(bsz): + if not active[b]: + continue + doc_start, doc_len = batch[b] + win_start, win_len, chunk_offset, chunk_len = _compute_chunk_window( + ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len + ) + tok_starts[b] = doc_start + win_start + tok_wls[b] = win_len + chunk_offsets_cpu[b] = chunk_offset + chunk_lens_cpu[b] = chunk_len + _, context_size, chunk_offset, _ = _compute_chunk_window( + ci, (ci + 1) * chunk_size, ci + 1, chunk_size, eval_seq_len + ) + col_idx = torch.arange(context_size + 1) + idx = tok_starts.unsqueeze(1) + col_idx.unsqueeze(0) + idx.clamp_(max=all_tokens.numel() - 1) + gathered_gpu = all_tokens_idx[idx].to( + device=device, dtype=torch.int64, non_blocking=True + ) + valid = (col_idx[:context_size].unsqueeze(0) < tok_wls.unsqueeze(1)).to( + device, non_blocking=True + ) + chunk_offsets = chunk_offsets_cpu.to(device, non_blocking=True) + chunk_lens = chunk_lens_cpu.to(device, non_blocking=True) + x = torch.where(valid, gathered_gpu[:, :context_size], 0) + y = torch.where(valid, gathered_gpu[:, 1 : context_size + 1], 0) + ctx_pos = torch.arange(context_size, device=device, dtype=torch.int64) + hint_ids_gpu = None + gate_mask_gpu = None + boost_gpu = None + if ngram_hint_global is not None: + hint_idx_cpu = ( + tok_starts.unsqueeze(1) + col_idx[:context_size].unsqueeze(0) + ).clamp_(min=0, max=ngram_hint_global.numel() - 1) + hint_ids_gpu = ngram_hint_global[hint_idx_cpu].to( + device=device, dtype=torch.int64, non_blocking=True + ) + gate_mask_gpu = ngram_gate_global[hint_idx_cpu].to( + device=device, non_blocking=True + ) + boost_gpu = ngram_boost_global[hint_idx_cpu].to( + device=device, dtype=torch.float32, non_blocking=True + ) + hint_ids_gpu = torch.where(valid, hint_ids_gpu, torch.zeros_like(hint_ids_gpu)) + gate_mask_gpu = gate_mask_gpu & valid + log_q_hint = None + if hint_ids_gpu is not None: + with torch.no_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_tok_loss, log_q_hint = forward_ttt_train( + x, y, lora=cur_lora, hint_ids=hint_ids_gpu + ) + else: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_tok_loss = forward_ttt_train(x, y, lora=cur_lora) + # CaseOps sidecar-driven byte budget. Mirror the index pattern + # used to build y from all_tokens: y[b, j] corresponds to the + # token at global position tok_starts[b] + 1 + j (when valid). + y_bytes_arg = None + if val_data.caseops_enabled and val_data.val_bytes is not None: + y_idx = ( + tok_starts.unsqueeze(1) + + 1 + + col_idx[:context_size].unsqueeze(0) + ) + y_idx = y_idx.clamp_(max=val_data.val_bytes.numel() - 1) + y_bytes_arg = val_data.val_bytes[y_idx].to( + device=device, dtype=torch.int32, non_blocking=True + ) + # Mirror the `valid` masking used for y so out-of-range tokens + # contribute zero bytes (matches y=0 substitution above). + y_bytes_arg = torch.where( + valid, y_bytes_arg, torch.zeros_like(y_bytes_arg) + ) + if hint_ids_gpu is not None and log_q_hint is not None: + from online_ngram_tilt import apply_tilt_to_ptl_torch_fast + + scored_loss = apply_tilt_to_ptl_torch_fast( + ptl=per_tok_loss, + log_q_hint=log_q_hint, + target_ids=y, + hint_ids=hint_ids_gpu, + gate_mask=gate_mask_gpu, + boost=boost_gpu, + ) + else: + scored_loss = per_tok_loss + with torch.no_grad(): + _accumulate_bpb( + scored_loss, + x, + y, + chunk_offsets, + chunk_lens, + ctx_pos, + val_data.base_bytes_lut, + val_data.has_leading_space_lut, + val_data.is_boundary_token_lut, + loss_sum, + byte_sum, + token_count, + y_bytes=y_bytes_arg, + ) + if scored_loss is not per_tok_loss: + del scored_loss + if needs_train: + activate_chunk_mask = (num_chunks_t - 1 > ci).float() + train_x, train_y = x, y + train_chunk_offset = chunk_offset + train_window = int(getattr(h, "ttt_train_window_tokens", 0)) + if train_window > 0 and context_size > max(train_window, chunk_size): + train_window = max(train_window, chunk_size) + train_end = min(context_size, chunk_offset + chunk_size) + train_start = max(0, train_end - train_window) + train_x = x[:, train_start:train_end].contiguous() + train_y = y[:, train_start:train_end].contiguous() + train_chunk_offset = chunk_offset - train_start + for gi in range(h.ttt_grad_steps): + if hint_ids_gpu is not None or gi > 0: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + train_per_tok_loss = forward_ttt_train( + train_x, train_y, lora=cur_lora + ) + else: + train_per_tok_loss = per_tok_loss + per_doc = train_per_tok_loss[ + :, train_chunk_offset : train_chunk_offset + chunk_size + ].mean(dim=-1) + cur_opt.zero_grad(set_to_none=True) + (per_doc * activate_chunk_mask).sum().backward() + cur_opt.step() + if train_per_tok_loss is not per_tok_loss: + del train_per_tok_loss + del per_tok_loss + batch_num = orig_batch_idx + 1 + doc_lens = [dl for _, dl in batch] + should_report = batch_num in eval_batch_set if eval_batch_set is not None else True + if should_report: + cur_tokens = token_count.item() + cur_loss_val = loss_sum.item() + cur_bytes_val = byte_sum.item() + dt = cur_tokens - prev_tokens + db = cur_bytes_val - prev_bytes + if dt > 0 and db > 0: + b_loss = (cur_loss_val - prev_loss) / dt + b_bpb = b_loss / math.log(2.0) * (dt / db) + else: + b_loss = b_bpb = 0.0 + r_loss = cur_loss_val / max(cur_tokens, 1) + r_bpb = r_loss / math.log(2.0) * (cur_tokens / max(cur_bytes_val, 1)) + elapsed = time.perf_counter() - t_start + log( + f"ttp: b{batch_num}/{queue_len} bl:{b_loss:.4f} bb:{b_bpb:.4f} " + f"rl:{r_loss:.4f} rb:{r_bpb:.4f} dl:{min(doc_lens)}-{max(doc_lens)} " + f"gd:{int(global_ttt_done)}" + ) + if not global_ttt_done: + local_scored_docs.extend( + (orig_batch_idx, pos, doc_start, doc_len) + for pos, (doc_start, doc_len) in enumerate(batch) + ) + prefix_done = _add_to_counter(prefix_counter_path, len(batch_entries)) + if prefix_done >= current_phase_boundary: + try: + with open(pause_flag_path, "x"): + pass + except FileExistsError: + pass + should_pause = os.path.exists(pause_flag_path) + if should_pause: + if dist.is_available() and dist.is_initialized(): + dist.barrier() + gathered_scored_docs = [None] * h.world_size + if dist.is_available() and dist.is_initialized(): + dist.all_gather_object(gathered_scored_docs, local_scored_docs) + else: + gathered_scored_docs = [local_scored_docs] + scored_docs_for_global = [] + for rank_docs in gathered_scored_docs: + if rank_docs: + scored_docs_for_global.extend(rank_docs) + scored_docs_for_global.sort(key=lambda x: (x[0], x[1])) + scored_docs_for_global = scored_docs_for_global[:current_phase_boundary] + scored_token_chunks = [ + val_data.val_tokens[doc_start : doc_start + doc_len] + for _, _, doc_start, doc_len in scored_docs_for_global + ] + if scored_token_chunks: + global_ttt_tokens = torch.cat(scored_token_chunks) + else: + global_ttt_tokens = val_data.val_tokens[:0] + if h.rank == 0: + prefix_done = 0 + try: + with open(prefix_counter_path, "rb") as f: + prefix_done = int.from_bytes( + f.read(8), "little", signed=True + ) + except FileNotFoundError: + pass + log( + f"ttpp: phase:{current_phase + 1}/{num_phases} pd:{prefix_done} " + f"gd:{len(scored_docs_for_global)} " + f"t:{time.perf_counter() - t_start:.1f}s" + ) + train_val_ttt_global_sgd_distributed( + h, device, val_data, base_model, global_ttt_tokens + ) + for p in base_model.parameters(): + p.requires_grad_(False) + reusable_lora = BatchedTTTLoRA( + h.ttt_batch_size, base_model, h.ttt_lora_rank, + q_lora=h.ttt_q_lora, k_lora=h.ttt_k_lora, v_lora=h.ttt_v_lora, + mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + reusable_opt = _build_opt(reusable_lora) + current_phase += 1 + if current_phase >= num_phases: + global_ttt_done = True + else: + current_phase_boundary = phase_boundaries[current_phase] + if h.rank == 0: + try: + os.remove(pause_flag_path) + except FileNotFoundError: + pass + if dist.is_available() and dist.is_initialized(): + dist.barrier() + if h.rank == 0: + log(f"ttpr: phase:{current_phase}/{num_phases} t:{time.perf_counter() - t_start:.1f}s") + del cur_lora, cur_opt + finally: + pass + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.train() + return _loss_bpb_from_sums(loss_sum, token_count, byte_sum) + + +def timed_eval(label, fn, *args, **kwargs): + torch.cuda.synchronize() + t0 = time.perf_counter() + val_loss, val_bpb = fn(*args, **kwargs) + torch.cuda.synchronize() + elapsed_ms = 1e3 * (time.perf_counter() - t0) + log( + f"{label} val_loss:{val_loss:.8f} val_bpb:{val_bpb:.8f} eval_time:{elapsed_ms:.0f}ms" + ) + return val_loss, val_bpb + + +def train_model(h, device, val_data): + base_model = GPT(h).to(device).bfloat16() + restore_fp32_params(base_model) + compile_enabled = os.environ.get("DISABLE_COMPILE", "0") != "1" + if compile_enabled: + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + compiled_forward_logits = torch.compile( + base_model.forward_logits, dynamic=False, fullgraph=True + ) + else: + log("compile:disabled_by_env") + compiled_model = base_model + compiled_forward_logits = base_model.forward_logits + model = compiled_model + log(f"model_params:{sum(p.numel()for p in base_model.parameters())}") + if h.jepa_lite_enabled: + log( + "jepa_lite_config: " + f"weight={h.jepa_lite_weight} hidden={h.jepa_lite_hidden} " + f"horizon={h.jepa_lite_horizon} stride={h.jepa_lite_stride} " + f"frequency={h.jepa_lite_frequency} start_frac={h.jepa_lite_start_frac} " + f"end_frac={h.jepa_lite_end_frac} lr={h.jepa_lite_lr}" + ) + if h.bijepa_enabled: + log( + "bijepa_config: " + f"backward_weight={h.bijepa_backward_weight} " + "dual_predictor=true training_only=true" + ) + bijepax_module = None + bijepax_opt = None + forward_with_hidden = None + if h.bijepax_enabled: + bijepax_module = MultiDirectionalBiJEPAX( + d_model=h.model_dim, + fwd_hops=h.bijepax_fwd_hops, + bwd_hops=h.bijepax_bwd_hops, + cycle=h.bijepax_cycle, + head_dim=h.bijepax_head_dim, + ).to(device).bfloat16() + bijepax_opt = torch.optim.Adam( + bijepax_module.parameters(), lr=h.bijepax_lr, betas=(0.9, 0.99) + ) + forward_with_hidden = ( + torch.compile(base_model.forward_with_hidden, dynamic=False, fullgraph=True) + if compile_enabled + else base_model.forward_with_hidden + ) + log( + "bijepax_config: " + f"weight={h.bijepax_weight} start={h.bijepax_start_frac} " + f"end={h.bijepax_end_frac} fwd={h.bijepax_fwd_hops} " + f"bwd={h.bijepax_bwd_hops} cycle={h.bijepax_cycle} " + f"head_dim={h.bijepax_head_dim} stride={h.bijepax_stride} " + f"lr={h.bijepax_lr} training_only=true" + ) + optimizers = Optimizers(h, base_model) + train_loader = DocumentPackingLoader(h, device) + max_wallclock_ms = ( + 1e3 * h.max_wallclock_seconds if h.max_wallclock_seconds > 0 else None + ) + if max_wallclock_ms is not None: + max_wallclock_ms -= h.gptq_reserve_seconds * 1e3 + log( + f"gptq:reserving {h.gptq_reserve_seconds:.0f}s, effective={max_wallclock_ms:.0f}ms" + ) + + def training_frac(step, elapsed_ms): + if max_wallclock_ms is None: + return step / max(h.iterations, 1) + return elapsed_ms / max(max_wallclock_ms, 1e-09) + + def lr_mul(frac): + if h.warmdown_frac <= 0: + return 1.0 + if frac >= 1.0 - h.warmdown_frac: + return max((1.0 - frac) / h.warmdown_frac, h.min_lr) + return 1.0 + + _clip_params = [p for p in base_model.parameters() if p.requires_grad] + def step_fn(step, lr_scale, frac_now=0.0): + if getattr(base_model, "jepa_predictor", None) is not None: + base_model.jepa_weight_current.fill_(jepa_weight_at(h, step, frac_now)) + bijepax_weight = ( + bijepax_weight_at( + frac_now, + h.bijepax_start_frac, + h.bijepax_end_frac, + h.bijepax_weight, + ) + if bijepax_module is not None + else 0.0 + ) + train_loss = torch.zeros((), device=device) + for micro_step in range(h.grad_accum_steps): + x, y, cu_seqlens, _max_seqlen = train_loader.next_batch( + h.train_batch_tokens, h.grad_accum_steps + ) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + if bijepax_module is not None and bijepax_weight > 0.0: + ce_loss, hidden = forward_with_hidden( + x, y, cu_seqlens=cu_seqlens, max_seqlen=h.train_seq_len + ) + loss = ce_loss + bijepax_module( + hidden, bijepax_weight, h.bijepax_stride + ) + else: + loss = model( + x, y, cu_seqlens=cu_seqlens, max_seqlen=h.train_seq_len + ) + train_loss += loss.detach() + (loss / h.grad_accum_steps).backward() + train_loss /= h.grad_accum_steps + if step <= h.muon_momentum_warmup_steps: + + frac = ( + + min(step / h.muon_momentum_warmup_steps, 1.0) + + if h.muon_momentum_warmup_steps > 0 + + else 1.0 + + ) + + muon_momentum = ( + + 1 - frac + + ) * h.muon_momentum_warmup_start + frac * h.muon_momentum + + for group in optimizers.optimizer_muon.param_groups: + + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * lr_scale + if h.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(_clip_params, h.grad_clip_norm) + optimizers.step(distributed=h.distributed) + if bijepax_opt is not None and bijepax_weight > 0.0: + bijepax_opt.step() + bijepax_opt.zero_grad(set_to_none=True) + return train_loss + + if h.warmup_steps > 0: + initial_model_state = { + name: tensor.detach().cpu().clone() + for (name, tensor) in base_model.state_dict().items() + } + initial_optimizer_states = [ + copy.deepcopy(opt.state_dict()) for opt in optimizers + ] + model.train() + num_tokens_local = h.train_batch_tokens // h.world_size + for blk in base_model.blocks: + blk.attn.rotary(num_tokens_local, device, torch.bfloat16) + cu_bucket_size = train_loader.cu_bucket_size + warmup_cu_buckets = tuple(cu_bucket_size * i for i in range(1, 5)) + warmup_cu_iters = 3 + x, y, cu_seqlens, _ = train_loader.next_batch( + h.train_batch_tokens, h.grad_accum_steps + ) + log(f"warmup_cu_buckets:{','.join(str(b) for b in warmup_cu_buckets)} iters_each:{warmup_cu_iters}") + def _run_cu_bucket_warmup(): + for bucket_len in warmup_cu_buckets: + boundaries = list(range(0, x.size(1), max(h.train_seq_len, 1))) + if boundaries[-1] != x.size(1): + boundaries.append(x.size(1)) + cu = torch.full((bucket_len,), x.size(1), dtype=torch.int32, device=device) + cu[: len(boundaries)] = torch.tensor(boundaries, dtype=torch.int32, device=device) + for _ in range(warmup_cu_iters): + optimizers.zero_grad_all() + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + wloss = model(x, y, cu_seqlens=cu, max_seqlen=h.train_seq_len) + (wloss / h.grad_accum_steps).backward() + optimizers.zero_grad_all() + _run_cu_bucket_warmup() + if h.num_loops > 0: + base_model.looping_active = True + _run_cu_bucket_warmup() + base_model.looping_active = False + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step, 1.0) + if ( + warmup_step <= 5 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == h.warmup_steps + ): + log(f"warmup_step: {warmup_step+1}/{h.warmup_steps}") + if h.num_loops > 0: + base_model.looping_active = True + log( + f"loop_warmup:enabled encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}" + ) + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step, 1.0) + if ( + warmup_step <= 5 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == h.warmup_steps + ): + log(f"loop_warmup_step: {warmup_step+1}/{h.warmup_steps}") + base_model.looping_active = False + base_model.load_state_dict(initial_model_state, strict=True) + for (opt, state) in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + optimizers.zero_grad_all() + train_loader = DocumentPackingLoader(h, device) + _live_state = base_model.state_dict(keep_vars=True) + ema_state = { + name: t.detach().float().clone() + for (name, t) in _live_state.items() + if not name.startswith("jepa_predictor.") + and not name.startswith("jepa_bwd_predictor.") + and name != "jepa_weight_current" + } + _ema_pairs = [ + (ema_state[name], t) + for (name, t) in _live_state.items() + if name in ema_state + ] + ema_decay = h.ema_decay + training_time_ms = 0.0 + forced_stop_step = int(os.environ.get("FORCE_STOP_STEP", "0")) + stop_after_step = forced_stop_step if forced_stop_step > 0 else None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = ( + step == h.iterations + or stop_after_step is not None + and step >= stop_after_step + ) + should_validate = ( + last_step or h.val_loss_every > 0 and step % h.val_loss_every == 0 + ) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1e3 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + h, device, val_data, model, compiled_forward_logits + ) + log( + f"{step}/{h.iterations} val_loss: {val_loss:.4f} val_bpb: {val_bpb:.4f}" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < h.iterations: + log( + f"stopping_early: wallclock_cap train_time: {training_time_ms:.0f}ms step: {step}/{h.iterations}" + ) + break + elapsed_ms = training_time_ms + 1e3 * (time.perf_counter() - t0) + frac = training_frac(step, elapsed_ms) + scale = lr_mul(frac) + if ( + h.num_loops > 0 + and not base_model.looping_active + and frac >= h.enable_looping_at + ): + base_model.looping_active = True + log( + f"layer_loop:enabled step:{step} frac:{frac:.3f} encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}" + ) + train_loss = step_fn(step, scale, frac) + with torch.no_grad(): + for ema_t, t in _ema_pairs: + ema_t.mul_(ema_decay).add_(t.detach(), alpha=1.0 - ema_decay) + step += 1 + approx_training_time_ms = training_time_ms + 1e3 * (time.perf_counter() - t0) + should_log_train = h.train_log_every > 0 and ( + step <= 5 or step % h.train_log_every == 0 or stop_after_step is not None + ) + if should_log_train: + tok_per_sec = step * h.train_batch_tokens / (approx_training_time_ms / 1e3) + log( + f"{step}/{h.iterations} train_loss: {train_loss.item():.4f} train_time: {approx_training_time_ms/60000:.1f}m tok/s: {tok_per_sec:.0f}" + ) + reached_cap = ( + forced_stop_step <= 0 + and max_wallclock_ms is not None + and approx_training_time_ms >= max_wallclock_ms + ) + if h.distributed and forced_stop_step <= 0 and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log( + f"peak memory allocated: {torch.cuda.max_memory_allocated()//1024//1024} MiB reserved: {torch.cuda.max_memory_reserved()//1024//1024} MiB" + ) + log("ema:applying EMA weights") + if getattr(base_model, "jepa_predictor", None) is not None: + base_model.jepa_predictor = None + log("jepa_lite_predictor_removed_before_serialize: true") + if getattr(base_model, "jepa_bwd_predictor", None) is not None: + base_model.jepa_bwd_predictor = None + log("bijepa_backward_predictor_removed_before_serialize: true") + current_state = base_model.state_dict() + avg_state = { + name: t.to(dtype=current_state[name].dtype) for (name, t) in ema_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + return base_model, compiled_model, compiled_forward_logits + + +def train_and_eval(h, device): + global BOS_ID + random.seed(h.seed) + np.random.seed(h.seed) + torch.manual_seed(h.seed) + torch.cuda.manual_seed_all(h.seed) + if h.artifact_dir and h.is_main_process: + os.makedirs(h.artifact_dir, exist_ok=True) + val_data = ValidationData(h, device) + log( + f"train_shards: {len(list(Path(h.datasets_dir).resolve().glob('fineweb_train_*.bin')))}" + ) + log(f"val_tokens: {val_data.val_tokens.numel()-1}") + # TTT_EVAL_ONLY: skip training + GPTQ, jump straight to TTT eval on a + # pre-existing quantized artifact. Used to test TTT-only improvements + # (e.g., PR-1767's alpha/warm-start/WD) without retraining. + ttt_eval_only = os.environ.get("TTT_EVAL_ONLY", "0") == "1" + quantize_only = os.environ.get("QUANTIZE_ONLY", "0") == "1" + if ttt_eval_only: + log("TTT_EVAL_ONLY=1 — skipping training + GPTQ, loading saved artifact for TTT eval") + log(f"ttt_lora_alpha: {BatchedLinearLoRA._ALPHA}") + log(f"ttt_warm_start_a: {BatchedLinearLoRA._WARM_START_A}") + log(f"ttt_weight_decay: {h.ttt_weight_decay}") + if h.jepa_lite_enabled: + h.jepa_lite_enabled = False + log("jepa_lite_disabled_for_eval_only: true") + elif quantize_only: + log("QUANTIZE_ONLY=1 — skipping training, loading saved full-precision checkpoint") + log(f"quantize_only checkpoint: {h.model_path}") + if h.jepa_lite_enabled: + h.jepa_lite_enabled = False + log("jepa_lite_disabled_for_quantize_only: true") + if BOS_ID is None: + BOS_ID = 1 + base_model = GPT(h).to(device).bfloat16() + state = torch.load(h.model_path, map_location="cpu") + base_model.load_state_dict(state, strict=True) + del state + serialize(h, base_model, Path(__file__).read_text(encoding="utf-8")) + if h.distributed: + dist.barrier() + else: + base_model, compiled_model, compiled_forward_logits = train_model( + h, device, val_data + ) + if h.jepa_lite_enabled: + h.jepa_lite_enabled = False + log("jepa_lite_disabled_after_training: true") + torch._dynamo.reset() + timed_eval( + "diagnostic pre-quantization post-ema", + eval_val, + h, + device, + val_data, + compiled_model, + compiled_forward_logits, + ) + if os.environ.get("PREQUANT_ONLY", "0") == "1": + log("PREQUANT_ONLY=1 — skipping serialize/GPTQ/post-quant eval/TTT") + return + serialize(h, base_model, Path(__file__).read_text(encoding="utf-8")) + if h.distributed: + dist.barrier() + eval_model = deserialize(h, device) + if h.num_loops > 0: + eval_model.looping_active = True + if not ttt_eval_only: + if os.environ.get("DISABLE_COMPILE", "0") == "1": + log("eval_compile:disabled_by_env") + compiled_model = eval_model + compiled_forward_logits = eval_model.forward_logits + else: + compiled_model = torch.compile(eval_model, dynamic=False, fullgraph=True) + compiled_forward_logits = torch.compile( + eval_model.forward_logits, dynamic=False, fullgraph=True + ) + timed_eval( + "diagnostic quantized", + eval_val, + h, + device, + val_data, + compiled_model, + compiled_forward_logits, + ) + if h.ttt_enabled or not h.ppm_mixer_enabled: + del eval_model + if h.ttt_enabled: + if not ttt_eval_only: + del compiled_model + if ttt_eval_only: + del eval_model + torch._dynamo.reset() + torch.cuda.empty_cache() + ttt_model = deserialize(h, device) + if h.num_loops > 0: + ttt_model.looping_active = True + for p in ttt_model.parameters(): + p.requires_grad_(False) + + if h.rope_yarn: + _yarn_seqlen = h.train_batch_tokens // h.grad_accum_steps + for block in ttt_model.blocks: + block.attn.rotary(_yarn_seqlen, device, torch.bfloat16) + else: + for block in ttt_model.blocks: + block.attn.rotary._cos_cached = None + block.attn.rotary._sin_cached = None + block.attn.rotary._seq_len_cached = 0 + block.attn.rotary(h.ttt_eval_seq_len, device, torch.bfloat16) + + def _fwd_ttt_inner(input_ids, target_ids, lora): + return ttt_model.forward_ttt(input_ids, target_ids, lora=lora) + + def _fwd_ttt_hint_inner(input_ids, target_ids, lora, hint_ids): + return ttt_model.forward_ttt( + input_ids, target_ids, lora=lora, hint_ids=hint_ids + ) + + _fwd_ttt_compiled_inner = None + _fwd_ttt_hint_compiled_inner = None + + def _fwd_ttt(input_ids, target_ids, lora, hint_ids=None): + nonlocal _fwd_ttt_compiled_inner, _fwd_ttt_hint_compiled_inner + if os.environ.get("DISABLE_COMPILE", "0") == "1": + if hint_ids is not None: + return _fwd_ttt_hint_inner( + input_ids, target_ids, lora=lora, hint_ids=hint_ids + ) + return _fwd_ttt_inner(input_ids, target_ids, lora=lora) + if hint_ids is not None: + if _fwd_ttt_hint_compiled_inner is None: + _fwd_ttt_hint_compiled_inner = torch.compile( + _fwd_ttt_hint_inner, dynamic=True + ) + return _fwd_ttt_hint_compiled_inner( + input_ids, target_ids, lora=lora, hint_ids=hint_ids + ) + if _fwd_ttt_compiled_inner is None: + _fwd_ttt_compiled_inner = torch.compile(_fwd_ttt_inner, dynamic=True) + return _fwd_ttt_compiled_inner(input_ids, target_ids, lora=lora) + + fwd_ttt_compiled = _fwd_ttt + log(f"ttt_lora:warming up compile (random tokens, no val data)") + if BOS_ID is None: + BOS_ID = 1 + t_warmup = time.perf_counter() + warmup_bszes = [h.ttt_batch_size] + for bsz in warmup_bszes: + wl = BatchedTTTLoRA( + bsz, ttt_model, h.ttt_lora_rank, + q_lora=h.ttt_q_lora, k_lora=h.ttt_k_lora, v_lora=h.ttt_v_lora, + mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + wo = torch.optim.AdamW( + wl.parameters(), + lr=h.ttt_lora_lr * h.ttt_local_lr_mult, + betas=(h.ttt_beta1, h.ttt_beta2), + eps=1e-10, + weight_decay=h.ttt_weight_decay, + fused=True, + ) + train_warmup_lens = [h.ttt_chunk_size] + train_window = int(getattr(h, "ttt_train_window_tokens", 0)) + if train_window > h.ttt_chunk_size: + train_warmup_lens.append(train_window) + for ctx_len in train_warmup_lens: + xw = torch.randint(0, h.vocab_size, (bsz, ctx_len), device=device, dtype=torch.int64) + yw = torch.randint(0, h.vocab_size, (bsz, ctx_len), device=device, dtype=torch.int64) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = fwd_ttt_compiled(xw, yw, lora=wl) + ptl[:, : min(h.ttt_chunk_size, ctx_len)].mean(dim=-1).sum().backward() + wo.step() + wo.zero_grad(set_to_none=True) + if h.ngram_tilt_enabled: + ctx_len = h.ttt_eval_seq_len + xw = torch.randint(0, h.vocab_size, (bsz, ctx_len), device=device, dtype=torch.int64) + yw = torch.randint(0, h.vocab_size, (bsz, ctx_len), device=device, dtype=torch.int64) + hintw = torch.randint( + 0, h.vocab_size, (bsz, ctx_len), device=device, dtype=torch.int64 + ) + with torch.no_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + fwd_ttt_compiled(xw, yw, lora=wl, hint_ids=hintw) + del wl, wo + torch.cuda.empty_cache() + compile_elapsed = time.perf_counter() - t_warmup + log(f"ttt_lora:compile warmup done ({compile_elapsed:.1f}s)") + precomputed_hints = None + if h.ngram_tilt_enabled and h.ngram_hint_precompute_outside: + log("ngram_tilt:precomputing hints before TTT eval timer") + precomputed_hints = _compute_ngram_hints_for_val(h, val_data, log0=log) + log("\nbeginning TTT eval timer") + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_val_loss, ttt_val_bpb = eval_val_ttt_phased( + h, + ttt_model, + device, + val_data, + forward_ttt_train=fwd_ttt_compiled, + precomputed_hints=precomputed_hints, + ) + torch.cuda.synchronize() + ttt_eval_elapsed = time.perf_counter() - t_ttt + log( + "quantized_ttt_phased " + f"val_loss:{ttt_val_loss:.8f} val_bpb:{ttt_val_bpb:.8f} " + f"eval_time:{1e3*ttt_eval_elapsed:.0f}ms" + ) + log(f"total_eval_time:{ttt_eval_elapsed:.1f}s") + if h.ppm_mixer_enabled: + import sys as _sys + log("beginning PPM sliding eval") + _sys.stdout.flush() + torch.cuda.synchronize() + if dist.is_available() and dist.is_initialized(): + dist.barrier() + t_ppm = time.perf_counter() + try: + ppm_val_loss, ppm_val_bpb = eval_val_ppm_sliding( + h, device, val_data, ttt_model, batch_seqs=16 + ) + torch.cuda.synchronize() + ppm_elapsed = time.perf_counter() - t_ppm + log( + f"ppm_sliding val_loss:{ppm_val_loss:.8f} val_bpb:{ppm_val_bpb:.8f} " + f"eval_time:{1e3*ppm_elapsed:.0f}ms" + ) + except Exception as _e: + log(f"PPM eval error: {_e}") + import traceback as _tb + log(_tb.format_exc()) + _sys.stdout.flush() + del ttt_model + elif h.ppm_mixer_enabled: + import sys as _sys + log("beginning PPM sliding eval") + _sys.stdout.flush() + torch.cuda.synchronize() + if dist.is_available() and dist.is_initialized(): + dist.barrier() + t_ppm = time.perf_counter() + try: + ppm_val_loss, ppm_val_bpb = eval_val_ppm_sliding( + h, device, val_data, eval_model, batch_seqs=16 + ) + torch.cuda.synchronize() + ppm_elapsed = time.perf_counter() - t_ppm + log( + f"ppm_sliding val_loss:{ppm_val_loss:.8f} val_bpb:{ppm_val_bpb:.8f} " + f"eval_time:{1e3*ppm_elapsed:.0f}ms" + ) + except Exception as _e: + log(f"PPM eval error: {_e}") + import traceback as _tb + log(_tb.format_exc()) + _sys.stdout.flush() + del eval_model + + +def main(): + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError( + f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral" + ) + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.set_float32_matmul_precision("high") + from torch.backends.cuda import ( + enable_cudnn_sdp, + enable_flash_sdp, + enable_math_sdp, + enable_mem_efficient_sdp, + ) + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + torch._dynamo.config.optimize_ddp = False + torch._dynamo.config.cache_size_limit = 64 + h = Hyperparameters() + set_logging_hparams(h) + if h.is_main_process: + os.makedirs(h.artifact_dir if h.artifact_dir else "logs", exist_ok=True) + log(100 * "=", console=False) + log("Hyperparameters:", console=True) + for (k, v) in sorted(vars(type(h)).items()): + if not k.startswith("_"): + log(f" {k}: {v}", console=True) + log("=" * 100, console=False) + log("Source code:", console=False) + log("=" * 100, console=False) + with open(__file__, "r", encoding="utf-8") as _src: + log(_src.read(), console=False) + log("=" * 100, console=False) + log(f"Running Python {sys.version}", console=False) + log(f"Running PyTorch {torch.__version__}", console=False) + log("=" * 100, console=False) + train_and_eval(h, device) + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +==================================================================================================== +train_shards: 80 +val_tokens: 47851520 +compile:disabled_by_env +model_params:35945673 +bijepax_config: weight=0.01 start=0.35 end=0.8 fwd=(4,) bwd=(4,) cycle=False head_dim=32 stride=64 lr=0.001 training_only=true +gptq:reserving 0s, effective=599500ms +warmup_cu_buckets:64,128,192,256 iters_each:3 +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup:enabled encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 +0/20000 val_loss: 9.0076 val_bpb: 4.1159 +1/20000 train_loss: 9.0087 train_time: 0.0m tok/s: 4294285 +2/20000 train_loss: 12.8207 train_time: 0.0m tok/s: 3779390 +3/20000 train_loss: 10.2260 train_time: 0.0m tok/s: 3626892 +4/20000 train_loss: 8.6824 train_time: 0.0m tok/s: 3557519 +5/20000 train_loss: 7.9442 train_time: 0.0m tok/s: 3518508 +250/20000 train_loss: 2.9190 train_time: 1.0m tok/s: 3366377 +500/20000 train_loss: 2.5569 train_time: 1.9m tok/s: 3370798 +750/20000 train_loss: 2.7019 train_time: 2.9m tok/s: 3372986 +layer_loop:enabled step:900 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +1000/20000 train_loss: 2.7481 train_time: 4.1m tok/s: 3211090 +1250/20000 train_loss: 2.5484 train_time: 5.5m tok/s: 2956874 +1500/20000 train_loss: 2.5173 train_time: 7.0m tok/s: 2808858 +1750/20000 train_loss: 2.5339 train_time: 8.5m tok/s: 2712218 +2000/20000 train_loss: 2.4834 train_time: 9.9m tok/s: 2643994 +2014/20000 val_loss: 2.4304 val_bpb: 1.1105 +stopping_early: wallclock_cap train_time: 599843ms step: 2014/20000 +peak memory allocated: 46676 MiB reserved: 52266 MiB +ema:applying EMA weights +diagnostic pre-quantization post-ema val_loss:2.42847180 val_bpb:1.10964479 eval_time:9908ms +Serialized model: 135418111 bytes +Code size (uncompressed): 213292 bytes +Code size (compressed): 41999 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 67 Hessians in 4.0s +Quantized weights: + gate_int8_row: blocks.attn.attn_gate_w + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int7)+awqgrpint8+lqer_asym: tok_emb.weight + passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, parallel_post_lambdas, parallel_resid_lambdas, skip_gates, skip_weights, smear_gate.weight, smear_lambda, softcap_neg, softcap_pos +Serialize: per-group lrzip compression... +Serialize: per-group compression done in 99.9s +Serialized model quantized+pergroup: 15955181 bytes +Total submission size quantized+pergroup: 15997180 bytes +Deserialize: per-group lrzip decompression... +Deserialize: decompression done in 17.4s +eval_compile:disabled_by_env +diagnostic quantized val_loss:2.44116551 val_bpb:1.11544494 eval_time:10342ms +beginning PPM sliding eval +ppm_mixer val_bpb:0.97234287 eval_time:456845ms order=5 H=0.999 L=0.18 T=0.8 N_tokens=47851520 N_sidecar_bytes=151074499 +ppm_sliding val_loss:2.45118426 val_bpb:0.97234287 eval_time:502131ms diff --git a/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/full_seed999.txt b/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/full_seed999.txt new file mode 100644 index 0000000000..d608066694 --- /dev/null +++ b/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/full_seed999.txt @@ -0,0 +1,5251 @@ +==================================================================================================== +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + agree_add_boost: 0.5 + artifact_dir: /workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s999_20260501_031407 + attn_clip_sigmas: 13.0 + attn_out_gate_enabled: False + attn_out_gate_src: proj + awq_lite_bits: 8 + awq_lite_enabled: True + awq_lite_group_size: 64 + awq_lite_group_top_k: 1 + beta1: 0.9 + beta2: 0.99 + bijepa_backward_weight: 1.0 + bijepa_enabled: False + bijepax_bwd_hops: (4,) + bijepax_cycle: False + bijepax_enabled: True + bijepax_end_frac: 0.8 + bijepax_fwd_hops: (4,) + bijepax_head_dim: 32 + bijepax_lr: 0.001 + bijepax_start_frac: 0.35 + bijepax_stride: 64 + bijepax_weight: 0.01 + caseops_enabled: True + compressor: pergroup + data_dir: ./data/ + datasets_dir: /workspace/parameter-golf/data/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved + distributed: True + ema_decay: 0.9965 + embed_bits: 7 + embed_clip_sigmas: 14.0 + embed_lr: 0.6 + embed_wd: 0.085 + enable_looping_at: 0.35 + eval_seq_len: 2048 + eval_stride: 512 + fused_ce_enabled: True + gate_window: 12 + gated_attn_enabled: False + gated_attn_init_std: 0.01 + gated_attn_quant_gate: True + global_ttt_batch_seqs: 32 + global_ttt_chunk_tokens: 32768 + global_ttt_epochs: 1 + global_ttt_grad_clip: 1.0 + global_ttt_lr: 0.001 + global_ttt_momentum: 0.9 + global_ttt_respect_doc_boundaries: True + global_ttt_warmup_chunks: 0 + global_ttt_warmup_start_lr: 0.0 + gptq_calibration_batches: 16 + gptq_reserve_seconds: 0.5 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + is_main_process: True + iterations: 20000 + jepa_lite_enabled: False + jepa_lite_end_frac: 0.85 + jepa_lite_frequency: 8 + jepa_lite_hidden: 128 + jepa_lite_horizon: 4 + jepa_lite_lr: 0.01 + jepa_lite_start_frac: 0.35 + jepa_lite_stride: 16 + jepa_lite_weight: 0.03 + ln_scale: True + local_rank: 0 + logfile: /workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s999_20260501_031407/v15_bijepaxlite_lqer1_nocompile_s999_20260501_031407.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + lqer_asym_enabled: True + lqer_asym_group: 64 + lqer_enabled: True + lqer_factor_bits: 4 + lqer_gain_select: False + lqer_rank: 4 + lqer_scope: all + lqer_top_k: 1 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.026 + max_wallclock_seconds: 600.0 + min_lr: 0.1 + mlp_clip_sigmas: 11.5 + mlp_mult: 4.0 + model_dim: 512 + model_path: /workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s999_20260501_031407/final_model.pt + muon_backend_steps: 5 + muon_momentum: 0.97 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + ngram_hint_precompute_outside: True + ngram_tilt_enabled: True + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 2 + parallel_final_lane: mean + parallel_start_layer: 8 + phased_ttt_num_phases: 3 + phased_ttt_prefix_docs: 2500 + ppm_dump_inputs: False + ppm_h: 0.999 + ppm_l: 0.18 + ppm_mixer_enabled: True + ppm_order: 5 + ppm_t: 0.8 + qk_gain_init: 5.25 + quantized_model_path: /workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s999_20260501_031407/final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + rope_yarn: False + run_id: v15_bijepaxlite_lqer1_nocompile_s999_20260501_031407 + scalar_lr: 0.02 + seed: 999 + skip_gates_enabled: True + smear_gate_enabled: True + sparse_attn_gate_enabled: True + sparse_attn_gate_init_std: 0.0 + sparse_attn_gate_scale: 0.5 + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + token_boost: 2.625 + token_order: 16 + token_threshold: 0.8 + tokenizer_path: /workspace/parameter-golf/data/tokenizers/fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model + train_batch_tokens: 786432 + train_files: /workspace/parameter-golf/data/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_train_*.bin + train_log_every: 250 + train_seq_len: 2048 + ttt_batch_size: 64 + ttt_beta1: 0.0 + ttt_beta2: 0.99 + ttt_chunk_size: 48 + ttt_enabled: False + ttt_eval_batches: + ttt_eval_seq_len: 2048 + ttt_grad_steps: 1 + ttt_k_lora: True + ttt_local_lr_mult: 0.75 + ttt_lora_lr: 0.0001 + ttt_lora_rank: 80 + ttt_mask: no_qv + ttt_mlp_lora: True + ttt_o_lora: True + ttt_optimizer: adam + ttt_q_lora: False + ttt_train_window_tokens: 0 + ttt_v_lora: False + ttt_weight_decay: 0.5 + val_batch_tokens: 524288 + val_bytes_files: /workspace/parameter-golf/data/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_val_bytes_*.bin + val_doc_fraction: 1.0 + val_files: /workspace/parameter-golf/data/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.85 + warmup_steps: 20 + within_boost: 0.75 + within_tau: 0.45 + word_boost: 0.75 + word_normalize: strip_punct_lower + word_order: 4 + word_tau: 0.65 + world_size: 8 + xsa_last_n: 11 +==================================================================================================== +Source code: +==================================================================================================== +import base64, collections, copy, fcntl, glob, io, lzma, math, os +from pathlib import Path +import random, re, subprocess, sys, time, uuid, numpy as np, sentencepiece as spm, torch, torch.distributed as dist, torch.nn.functional as F +from torch import Tensor, nn +from flash_attn_interface import ( + flash_attn_func as flash_attn_3_func, + flash_attn_varlen_func, +) +from concurrent.futures import ThreadPoolExecutor +import triton +import triton.language as tl +from triton.tools.tensor_descriptor import TensorDescriptor + + +# ===== Fused softcapped cross-entropy (Triton) — training-only path ===== +# Replaces the eager +# logits_softcap = softcap * tanh(logits / softcap) +# F.cross_entropy(logits_softcap.float(), targets, reduction="mean") +# sequence with a single fused kernel that reads logits_proj once, applies +# softcap in-register, and computes (LSE, loss) in one streaming pass. The +# backward kernel mirrors the forward so there's no stored softcapped logits. +# Numerically identical to the eager path up to fp32 accumulation differences. +_FUSED_CE_LIBRARY = "pgsubmission1draft7fusedce" +_FUSED_CE_BLOCK_SIZE = 1024 +_FUSED_CE_NUM_WARPS = 4 + + +@triton.jit +def _softcapped_ce_fwd_kernel( + logits_ptr, losses_ptr, lse_ptr, targets_ptr, + stride_logits_n, stride_logits_v, + n_rows, n_cols, softcap, + block_size: tl.constexpr, +): + row_idx = tl.program_id(0).to(tl.int64) + logits_row_ptr = logits_ptr + row_idx * stride_logits_n + max_val = -float("inf") + sum_exp = 0.0 + A = 2.0 * softcap + inv_C = 2.0 / softcap + for off in range(0, n_cols, block_size): + cols = off + tl.arange(0, block_size) + mask = cols < n_cols + val = tl.load( + logits_row_ptr + cols * stride_logits_v, + mask=mask, other=-float("inf"), + ).to(tl.float32) + z = A * tl.sigmoid(val * inv_C) + z = tl.where(mask, z, -float("inf")) + curr_max = tl.max(z, axis=0) + new_max = tl.maximum(max_val, curr_max) + sum_exp = sum_exp * tl.exp(max_val - new_max) + tl.sum(tl.exp(z - new_max), axis=0) + max_val = new_max + lse = max_val + tl.log(sum_exp) + tl.store(lse_ptr + row_idx, lse) + target = tl.load(targets_ptr + row_idx).to(tl.int32) + target_val = tl.load(logits_row_ptr + target * stride_logits_v).to(tl.float32) + target_z = A * tl.sigmoid(target_val * inv_C) + tl.store(losses_ptr + row_idx, lse - target_z) + + +@triton.jit +def _softcapped_ce_bwd_kernel( + grad_logits_ptr, grad_losses_ptr, lse_ptr, logits_ptr, targets_ptr, + stride_logits_n, stride_logits_v, + stride_grad_n, stride_grad_v, + n_rows, n_cols, softcap, + block_size: tl.constexpr, +): + row_idx = tl.program_id(0).to(tl.int64) + logits_row_ptr = logits_ptr + row_idx * stride_logits_n + grad_row_ptr = grad_logits_ptr + row_idx * stride_grad_n + lse = tl.load(lse_ptr + row_idx) + grad_loss = tl.load(grad_losses_ptr + row_idx).to(tl.float32) + target = tl.load(targets_ptr + row_idx).to(tl.int32) + A = 2.0 * softcap + inv_C = 2.0 / softcap + dz_dx_scale = A * inv_C + for off in range(0, n_cols, block_size): + cols = off + tl.arange(0, block_size) + mask = cols < n_cols + val = tl.load( + logits_row_ptr + cols * stride_logits_v, + mask=mask, other=0.0, + ).to(tl.float32) + sigmoid_u = tl.sigmoid(val * inv_C) + z = A * sigmoid_u + probs = tl.exp(z - lse) + grad_z = grad_loss * (probs - tl.where(cols == target, 1.0, 0.0)) + grad_x = grad_z * (dz_dx_scale * sigmoid_u * (1.0 - sigmoid_u)) + tl.store(grad_row_ptr + cols * stride_grad_v, grad_x, mask=mask) + + +def _validate_softcapped_ce_inputs( + logits: Tensor, targets: Tensor, softcap: float, +) -> tuple[Tensor, Tensor]: + if logits.ndim != 2: + raise ValueError(f"Expected logits.ndim=2, got {logits.ndim}") + if targets.ndim != 1: + raise ValueError(f"Expected targets.ndim=1, got {targets.ndim}") + if logits.shape[0] != targets.shape[0]: + raise ValueError( + f"Expected matching rows, got logits={tuple(logits.shape)} targets={tuple(targets.shape)}" + ) + if not logits.is_cuda or not targets.is_cuda: + raise ValueError("softcapped_cross_entropy requires CUDA tensors") + if softcap <= 0.0: + raise ValueError(f"softcap must be positive, got {softcap}") + if logits.dtype not in (torch.float16, torch.bfloat16, torch.float32): + raise ValueError(f"Unsupported logits dtype: {logits.dtype}") + logits = logits.contiguous() + targets = targets.contiguous() + if targets.dtype != torch.int64: + targets = targets.to(dtype=torch.int64) + return logits, targets + + +@torch.library.custom_op(f"{_FUSED_CE_LIBRARY}::softcapped_ce", mutates_args=()) +def softcapped_ce_op(logits: Tensor, targets: Tensor, softcap: float) -> tuple[Tensor, Tensor]: + logits, targets = _validate_softcapped_ce_inputs(logits, targets, float(softcap)) + n_rows, n_cols = logits.shape + losses = torch.empty((n_rows,), device=logits.device, dtype=torch.float32) + lse = torch.empty((n_rows,), device=logits.device, dtype=torch.float32) + _softcapped_ce_fwd_kernel[(n_rows,)]( + logits, losses, lse, targets, + logits.stride(0), logits.stride(1), + n_rows, n_cols, float(softcap), + block_size=_FUSED_CE_BLOCK_SIZE, num_warps=_FUSED_CE_NUM_WARPS, + ) + return losses, lse + + +@softcapped_ce_op.register_fake +def _(logits: Tensor, targets: Tensor, softcap: float): + if logits.ndim != 2 or targets.ndim != 1: + raise ValueError("softcapped_ce fake impl expects 2D logits and 1D targets") + if logits.shape[0] != targets.shape[0]: + raise ValueError( + f"Expected matching rows, got logits={tuple(logits.shape)} targets={tuple(targets.shape)}" + ) + n_rows = logits.shape[0] + return ( + logits.new_empty((n_rows,), dtype=torch.float32), + logits.new_empty((n_rows,), dtype=torch.float32), + ) + + +@torch.library.custom_op(f"{_FUSED_CE_LIBRARY}::softcapped_ce_backward", mutates_args=()) +def softcapped_ce_backward_op( + logits: Tensor, targets: Tensor, lse: Tensor, grad_losses: Tensor, softcap: float, +) -> Tensor: + logits, targets = _validate_softcapped_ce_inputs(logits, targets, float(softcap)) + lse = lse.contiguous() + grad_losses = grad_losses.contiguous().to(dtype=torch.float32) + if lse.ndim != 1 or grad_losses.ndim != 1: + raise ValueError("Expected 1D lse and grad_losses") + if lse.shape[0] != logits.shape[0] or grad_losses.shape[0] != logits.shape[0]: + raise ValueError( + f"Expected row-aligned lse/grad_losses, got logits={tuple(logits.shape)} " + f"lse={tuple(lse.shape)} grad_losses={tuple(grad_losses.shape)}" + ) + grad_logits = torch.empty_like(logits) + n_rows, n_cols = logits.shape + _softcapped_ce_bwd_kernel[(n_rows,)]( + grad_logits, grad_losses, lse, logits, targets, + logits.stride(0), logits.stride(1), + grad_logits.stride(0), grad_logits.stride(1), + n_rows, n_cols, float(softcap), + block_size=_FUSED_CE_BLOCK_SIZE, num_warps=_FUSED_CE_NUM_WARPS, + ) + return grad_logits + + +@softcapped_ce_backward_op.register_fake +def _(logits: Tensor, targets: Tensor, lse: Tensor, grad_losses: Tensor, softcap: float): + if logits.ndim != 2 or targets.ndim != 1 or lse.ndim != 1 or grad_losses.ndim != 1: + raise ValueError("softcapped_ce_backward fake impl expects 2D logits and 1D row tensors") + if ( + logits.shape[0] != targets.shape[0] + or logits.shape[0] != lse.shape[0] + or logits.shape[0] != grad_losses.shape[0] + ): + raise ValueError("softcapped_ce_backward fake impl expects row-aligned tensors") + return logits.new_empty(logits.shape) + + +def _softcapped_ce_setup_context( + ctx: torch.autograd.function.FunctionCtx, inputs, output, +) -> None: + logits, targets, softcap = inputs + _losses, lse = output + ctx.save_for_backward(logits, targets, lse) + ctx.softcap = float(softcap) + + +def _softcapped_ce_backward( + ctx: torch.autograd.function.FunctionCtx, grad_losses: Tensor, grad_lse: "Tensor | None", +): + del grad_lse + logits, targets, lse = ctx.saved_tensors + grad_logits = torch.ops.pgsubmission1draft7fusedce.softcapped_ce_backward( + logits, targets, lse, grad_losses, ctx.softcap + ) + return grad_logits, None, None + + +softcapped_ce_op.register_autograd( + _softcapped_ce_backward, setup_context=_softcapped_ce_setup_context, +) + + +def softcapped_cross_entropy( + logits: Tensor, targets: Tensor, softcap: float, reduction: str = "mean", +) -> Tensor: + losses, _lse = torch.ops.pgsubmission1draft7fusedce.softcapped_ce( + logits, targets, float(softcap) + ) + if reduction == "none": + return losses + if reduction == "sum": + return losses.sum() + if reduction == "mean": + return losses.mean() + raise ValueError(f"Unsupported reduction={reduction!r}") + + +class _BiJEPAXHead(nn.Module): + def __init__(self, d_model: int, head_dim: int): + super().__init__() + self.w1 = nn.Linear(d_model, head_dim, bias=False) + self.ln = nn.LayerNorm(head_dim) + self.w2 = nn.Linear(head_dim, d_model, bias=False) + nn.init.orthogonal_(self.w1.weight) + nn.init.zeros_(self.w2.weight) + + def forward(self, x: Tensor) -> Tensor: + return self.w2(F.gelu(self.ln(self.w1(x.float())))) + + +class MultiDirectionalBiJEPAX(nn.Module): + """Training-only multi-hop bidirectional/cycle JEPA regularizer.""" + + def __init__( + self, + d_model: int, + fwd_hops: tuple[int, ...] = (1, 4, 16), + bwd_hops: tuple[int, ...] = (1, 4), + cycle: bool = True, + head_dim: int = 64, + ): + super().__init__() + self.fwd_hops = list(fwd_hops) + self.bwd_hops = list(bwd_hops) + self.cycle = cycle + self.fwd_heads = nn.ModuleList( + [_BiJEPAXHead(d_model, head_dim) for _ in fwd_hops] + ) + self.bwd_heads = nn.ModuleList( + [_BiJEPAXHead(d_model, head_dim) for _ in bwd_hops] + ) + self.cyc_heads = nn.ModuleList( + [_BiJEPAXHead(d_model, head_dim) for _ in fwd_hops] + ) if cycle else nn.ModuleList() + + @staticmethod + def _cos_loss(pred: Tensor, tgt: Tensor) -> Tensor: + pred = F.normalize(pred, dim=-1) + tgt = F.normalize(tgt.float().detach(), dim=-1) + return (1.0 - (pred * tgt).sum(dim=-1)).mean() + + def forward(self, hidden: Tensor, weight: float, stride: int) -> Tensor: + if weight <= 0.0 or hidden.size(1) < 2: + return hidden.new_zeros(()) + stride = max(1, int(stride)) + seq_len = hidden.size(1) + losses: list[Tensor] = [] + for i, hop in enumerate(self.fwd_hops): + if seq_len <= hop: + continue + src = hidden[:, : seq_len - hop : stride, :] + tgt = hidden[:, hop : hop + src.size(1) * stride : stride, :] + n = min(src.size(1), tgt.size(1)) + if n <= 0: + continue + src = src[:, :n, :] + tgt = tgt[:, :n, :] + losses.append(self._cos_loss(self.fwd_heads[i](src), tgt)) + if self.cycle and i < len(self.cyc_heads): + losses.append(self._cos_loss(self.cyc_heads[i](tgt.detach()), src)) + for i, hop in enumerate(self.bwd_hops): + if seq_len <= hop: + continue + src = hidden[:, hop::stride, :] + tgt = hidden[:, : src.size(1) * stride : stride, :] + n = min(src.size(1), tgt.size(1)) + if n <= 0: + continue + losses.append(self._cos_loss(self.bwd_heads[i](src[:, :n, :]), tgt[:, :n, :])) + if not losses: + return hidden.new_zeros(()) + return weight * (sum(losses) / len(losses)) + + +def bijepax_weight_at(frac: float, start: float, end: float, peak: float) -> float: + if frac < start or frac >= end: + return 0.0 + span = max(end - start, 1e-9) + ramp = min(span * 0.15, 0.08) + if frac < start + ramp: + return peak * (frac - start) / max(ramp, 1e-9) + if frac >= end - ramp: + return peak * (end - frac) / max(ramp, 1e-9) + return peak + + +class Hyperparameters: + data_dir = os.environ.get("DATA_DIR", "./data/") + seed = int(os.environ.get("SEED", 1337)) + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_frac = float(os.environ.get("WARMDOWN_FRAC", 0.85)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786432)) + # Fused softcapped CE (Triton). Training-only — forward_logits eval path still uses + # eager softcap+F.cross_entropy. Default ON since validated as at-worst neutral. + fused_ce_enabled = bool(int(os.environ.get("FUSED_CE_ENABLED", "1"))) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 6e2)) + val_batch_tokens = int(os.environ.get("VAL_BATCH_TOKENS", 524288)) + # v13 is the sidecar-aware PPM lane. These defaults match the under-cap + # H100 package runs instead of the older TTT-first v12 defaults. + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 8192)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 11)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 4.0)) + skip_gates_enabled = bool(int(os.environ.get("SKIP_GATES_ENABLED", "1"))) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 3e1)) + rope_base = float(os.environ.get("ROPE_BASE", 1e4)) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + rope_train_seq_len = int(os.environ.get("ROPE_TRAIN_SEQ_LEN", 2048)) + rope_yarn = bool(int(os.environ.get("ROPE_YARN", "0"))) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 5.25)) + num_loops = int(os.environ.get("NUM_LOOPS", 2)) + loop_start = int(os.environ.get("LOOP_START", 3)) + loop_end = int(os.environ.get("LOOP_END", 5)) + enable_looping_at = float(os.environ.get("ENABLE_LOOPING_AT", 0.35)) + parallel_start_layer = int(os.environ.get("PARALLEL_START_LAYER", 8)) + parallel_final_lane = os.environ.get("PARALLEL_FINAL_LANE", "mean") + min_lr = float(os.environ.get("MIN_LR", 0.1)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.026)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.97)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float( + os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92) + ) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + muon_row_normalize = bool(int(os.environ.get("MUON_ROW_NORMALIZE", "1"))) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.99)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-08)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 512)) + adam_wd = float(os.environ.get("ADAM_WD", 0.02)) + muon_wd = float(os.environ.get("MUON_WD", 0.095)) + embed_wd = float(os.environ.get("EMBED_WD", 0.085)) + ema_decay = float(os.environ.get("EMA_DECAY", 0.9965)) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 80)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.0001)) + ttt_local_lr_mult = float(os.environ.get("TTT_LOCAL_LR_MULT", 0.75)) + ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 48)) + ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 2048)) + ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64)) + ttt_grad_steps = int(os.environ.get("TTT_GRAD_STEPS", 1)) + ttt_train_window_tokens = int(os.environ.get("TTT_TRAIN_WINDOW_TOKENS", 0)) + # V19: PR #1886 (renqianluo) + sunnypatneedi research log 2026-04-28 found that + # the Triton fused-CE kernel's fp32-accumulation interacts with warm-start LoRA-A + # to destabilize seeds 314/1337 at TTT_WEIGHT_DECAY=1.0. Raising the default to + # 2.0 prevents seed collapse without measurably moving stable seeds. + ttt_weight_decay = float(os.environ.get("TTT_WEIGHT_DECAY", 0.5)) + ttt_beta1 = float(os.environ.get("TTT_BETA1", 0)) + ttt_beta2 = float(os.environ.get("TTT_BETA2", 0.99)) + ttt_mask = os.environ.get("TTT_MASK", "no_qv").strip().lower() + _ttt_q_default = "1" + _ttt_v_default = "1" + if ttt_mask in ("", "all", "baseline_all"): + pass + elif ttt_mask == "no_q": + _ttt_q_default = "0" + elif ttt_mask == "no_v": + _ttt_v_default = "0" + elif ttt_mask == "no_qv": + _ttt_q_default = "0" + _ttt_v_default = "0" + else: + raise ValueError(f"Unsupported TTT_MASK={ttt_mask!r}") + ttt_q_lora = bool(int(os.environ.get("TTT_Q_LORA", _ttt_q_default))) + ttt_k_lora = bool(int(os.environ.get("TTT_K_LORA", "1"))) + ttt_v_lora = bool(int(os.environ.get("TTT_V_LORA", _ttt_v_default))) + ttt_mlp_lora = bool(int(os.environ.get("TTT_MLP_LORA", "1"))) + ttt_o_lora = bool(int(os.environ.get("TTT_O_LORA", "1"))) + ttt_optimizer = os.environ.get("TTT_OPTIMIZER", "adam") + ttt_eval_batches = os.environ.get("TTT_EVAL_BATCHES", "") + val_doc_fraction = float(os.environ.get("VAL_DOC_FRACTION", 1.0)) + compressor = os.environ.get("COMPRESSOR", "pergroup") + gptq_calibration_batches = int(os.environ.get("GPTQ_CALIBRATION_BATCHES", 16)) + gptq_reserve_seconds = float(os.environ.get("GPTQ_RESERVE_SECONDS", 0.5)) + phased_ttt_prefix_docs = int(os.environ.get("PHASED_TTT_PREFIX_DOCS", 2500)) + phased_ttt_num_phases = int(os.environ.get("PHASED_TTT_NUM_PHASES", 3)) + global_ttt_lr = float(os.environ.get("GLOBAL_TTT_LR", 0.001)) + global_ttt_momentum = float(os.environ.get("GLOBAL_TTT_MOMENTUM", 0.9)) + global_ttt_epochs = int(os.environ.get("GLOBAL_TTT_EPOCHS", 1)) + global_ttt_chunk_tokens = int(os.environ.get("GLOBAL_TTT_CHUNK_TOKENS", 32768)) + global_ttt_batch_seqs = int(os.environ.get("GLOBAL_TTT_BATCH_SEQS", 32)) + global_ttt_warmup_start_lr = float(os.environ.get("GLOBAL_TTT_WARMUP_START_LR", 0.0)) + global_ttt_warmup_chunks = int(os.environ.get("GLOBAL_TTT_WARMUP_CHUNKS", 0)) + global_ttt_grad_clip = float(os.environ.get("GLOBAL_TTT_GRAD_CLIP", 1.0)) + global_ttt_respect_doc_boundaries = bool(int(os.environ.get("GLOBAL_TTT_RESPECT_DOC_BOUNDARIES", "1"))) + matrix_bits = int(os.environ.get("MATRIX_BITS", 6)) + embed_bits = int(os.environ.get("EMBED_BITS", 7)) + matrix_clip_sigmas = float(os.environ.get("MATRIX_CLIP_SIGMAS", 12.85)) + embed_clip_sigmas = float(os.environ.get("EMBED_CLIP_SIGMAS", 14.0)) + mlp_clip_sigmas = float(os.environ.get("MLP_CLIP_SIGMAS", 11.5)) + attn_clip_sigmas = float(os.environ.get("ATTN_CLIP_SIGMAS", 13.0)) + # AttnOutGate (per-head multiplicative output gate, PR #1667 MarioPaerle). + # Zero-init weight: 2*sigmoid(0)=1 -> transparent at start. Source defaults to + # block input x ('proj'); 'q' uses raw Q projection output. + attn_out_gate_enabled = bool(int(os.environ.get("ATTN_OUT_GATE_ENABLED", "0"))) + attn_out_gate_src = os.environ.get("ATTN_OUT_GATE_SRC", "proj") + # SmearGate (input-dependent forward-1 token smear, modded-nanogpt @classiclarryd + # via PR #1667). x_t <- x_t + lam * sigmoid(W*x_t[:gate_window]) * x_{t-1}. + # lam=0 + W=0 -> transparent at init. + smear_gate_enabled = bool(int(os.environ.get("SMEAR_GATE_ENABLED", "1"))) + # Window: first GATE_WINDOW dims of the source feed the gate projection. + gate_window = int(os.environ.get("GATE_WINDOW", 12)) + # Gated Attention (Qwen, NeurIPS 2025 Best Paper, arXiv:2505.06708; + # qiuzh20/gated_attention). Per-head sigmoid gate on SDPA output, BEFORE + # out_proj. Gate input = full block input x (paper's headwise G1 variant + # driven from hidden_states). W_g shape (num_heads, dim), plain sigmoid. + # Near-zero init gives g~0.5 at step 0 (half attention output); per-block + # attn_scale (init 1.0) compensates during training. Name contains + # "attn_gate" so CONTROL_TENSOR_NAME_PATTERNS routes it to scalar AdamW. + gated_attn_enabled = bool(int(os.environ.get("GATED_ATTN_ENABLED", "0"))) + gated_attn_init_std = float(os.environ.get("GATED_ATTN_INIT_STD", 0.01)) + # Dedicated int8-per-row quantization for `attn_gate_w` tensors. These are + # small ((num_heads, dim) = (8, 512) = 4096 params) and bypass GPTQ via the + # numel<=65536 passthrough branch -> stored as fp16 (8 KB/layer, ~65 KB total + # compressed). int8-per-row cuts the raw tensor in half with negligible BPB + # impact: scales per head (8 values), symmetric quant over [-127, 127]. + # No Hessian needed (gate weights not in collect_hessians()). + gated_attn_quant_gate = bool(int(os.environ.get("GATED_ATTN_QUANT_GATE", "1"))) + # Sparse Attention Gate (modded-nanogpt-style). Keeps dense SDPA and only + # swaps the output-gate input to the first GATE_WINDOW residual dims. + # W_g: (num_heads, gate_window) = (8, 12) = 96 params/layer (~44K total), + # vs dense GatedAttn's (8, 512) = 4K/layer (~44K diff). Name "attn_gate_w" + # is shared so quant routing and int8 gate passthrough Just Work. Gate + # passthrough int8 still applies via GATED_ATTN_QUANT_GATE=1. + # Mutually exclusive with ATTN_OUT_GATE_ENABLED and GATED_ATTN_ENABLED. + sparse_attn_gate_enabled = bool(int(os.environ.get("SPARSE_ATTN_GATE_ENABLED", "1"))) + sparse_attn_gate_init_std = float(os.environ.get("SPARSE_ATTN_GATE_INIT_STD", 0.0)) + sparse_attn_gate_scale = float(os.environ.get("SPARSE_ATTN_GATE_SCALE", 0.5)) + # LQER asymmetric rank-k correction on top-K quant-error tensors (PR #1530 v2 port). + # Computes SVD of E = W_fp - W_quant, packs top-r A,B as INT2/INT4 (asym) or INTk (sym). + lqer_enabled = bool(int(os.environ.get("LQER_ENABLED", "1"))) + lqer_rank = int(os.environ.get("LQER_RANK", 4)) + lqer_top_k = int(os.environ.get("LQER_TOP_K", 3)) + lqer_factor_bits = int(os.environ.get("LQER_FACTOR_BITS", 4)) + lqer_asym_enabled = bool(int(os.environ.get("LQER_ASYM_ENABLED", "1"))) + lqer_asym_group = int(os.environ.get("LQER_ASYM_GROUP", "64")) + lqer_scope = os.environ.get("LQER_SCOPE", "all") + lqer_gain_select = bool(int(os.environ.get("LQER_GAIN_SELECT", "0"))) + awq_lite_enabled = bool(int(os.environ.get("AWQ_LITE_ENABLED", "1"))) + awq_lite_bits = int(os.environ.get("AWQ_LITE_BITS", "8")) + awq_lite_group_top_k = int(os.environ.get("AWQ_LITE_GROUP_TOP_K", "1")) + awq_lite_group_size = int(os.environ.get("AWQ_LITE_GROUP_SIZE", "64")) + # PR #1145/#1967 online n-gram tilt. This is a causal scoring overlay: + # prefix-only token/within-word/word experts propose one hint token, then + # the per-token NLL is adjusted with closed-form softmax renormalization. + ngram_tilt_enabled = bool(int(os.environ.get("NGRAM_TILT_ENABLED", "1"))) + token_order = int(os.environ.get("TOKEN_ORDER", "16")) + token_threshold = float(os.environ.get("TOKEN_THRESHOLD", "0.800")) + token_boost = float(os.environ.get("TOKEN_BOOST", "2.625")) + within_tau = float(os.environ.get("WITHIN_TAU", "0.450")) + within_boost = float(os.environ.get("WITHIN_BOOST", "0.750")) + word_order = int(os.environ.get("WORD_ORDER", "4")) + word_normalize = os.environ.get("WORD_NORMALIZE", "strip_punct_lower") + word_tau = float(os.environ.get("WORD_TAU", "0.650")) + word_boost = float(os.environ.get("WORD_BOOST", "0.750")) + agree_add_boost = float(os.environ.get("AGREE_ADD_BOOST", "0.500")) + ngram_hint_precompute_outside = bool(int(os.environ.get("NGRAM_HINT_PRECOMPUTE_OUTSIDE", "1"))) + ppm_mixer_enabled = bool(int(os.environ.get("PPM_MIXER_ENABLED", "1"))) + ppm_order = int(os.environ.get("PPM_ORDER", "5")) + ppm_h = float(os.environ.get("PPM_H", "0.995")) + ppm_l = float(os.environ.get("PPM_L", "0.20")) + ppm_t = float(os.environ.get("PPM_T", "0.80")) + ppm_dump_inputs = bool(int(os.environ.get("PPM_DUMP_INPUTS", "0"))) + # JEPA-Lite training regularizer from PR #2027. The predictor is never + # serialized; it only nudges hidden states during training. + jepa_lite_enabled = bool(int(os.environ.get("JEPA_LITE_ENABLED", "0"))) + jepa_lite_weight = float(os.environ.get("JEPA_LITE_WEIGHT", "0.03")) + jepa_lite_hidden = int(os.environ.get("JEPA_LITE_HIDDEN", "128")) + jepa_lite_horizon = int(os.environ.get("JEPA_LITE_HORIZON", "4")) + jepa_lite_stride = int(os.environ.get("JEPA_LITE_STRIDE", "16")) + jepa_lite_frequency = int(os.environ.get("JEPA_LITE_FREQUENCY", "8")) + jepa_lite_start_frac = float(os.environ.get("JEPA_LITE_START_FRAC", "0.35")) + jepa_lite_end_frac = float(os.environ.get("JEPA_LITE_END_FRAC", "0.85")) + jepa_lite_lr = float(os.environ.get("JEPA_LITE_LR", "0.01")) + bijepa_enabled = bool(int(os.environ.get("BIJEPA_ENABLED", "0"))) + bijepa_backward_weight = float(os.environ.get("BIJEPA_BACKWARD_WEIGHT", "1.0")) + bijepax_enabled = bool(int(os.environ.get("BIJEPAX_ENABLED", "0"))) + bijepax_weight = float(os.environ.get("BIJEPAX_WEIGHT", "0.04")) + bijepax_start_frac = float(os.environ.get("BIJEPAX_START_FRAC", "0.05")) + bijepax_end_frac = float(os.environ.get("BIJEPAX_END_FRAC", "0.80")) + bijepax_fwd_hops = tuple( + int(x) for x in os.environ.get("BIJEPAX_FWD_HOPS", "1,4,16").split(",") if x + ) + bijepax_bwd_hops = tuple( + int(x) for x in os.environ.get("BIJEPAX_BWD_HOPS", "1,4").split(",") if x + ) + bijepax_cycle = bool(int(os.environ.get("BIJEPAX_CYCLE", "1"))) + bijepax_head_dim = int(os.environ.get("BIJEPAX_HEAD_DIM", "64")) + bijepax_stride = int(os.environ.get("BIJEPAX_STRIDE", "8")) + bijepax_lr = float(os.environ.get("BIJEPAX_LR", "3e-4")) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + is_main_process = rank == 0 + grad_accum_steps = 8 // world_size + # CaseOps integration: optional override of dataset root + tokenizer path. + # When CASEOPS_ENABLED=1, the wrapper loads a per-token byte sidecar + # (fineweb_val_bytes_*.bin, identical shard layout to val_*.bin) and uses + # it as the canonical raw-byte budget for BPB accounting. The sidecar + # REPLACES the build_sentencepiece_luts byte-counting path entirely. + caseops_enabled = bool(int(os.environ.get("CASEOPS_ENABLED", "1"))) + _default_caseops_data = os.path.join( + data_dir, + "datasets", + "fineweb10B_sp8192_caseops", + "datasets", + "datasets", + "fineweb10B_sp8192_lossless_caps_caseops_v1_reserved", + ) + _default_caseops_tok = os.path.join( + data_dir, + "datasets", + "fineweb10B_sp8192_caseops", + "datasets", + "tokenizers", + "fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model", + ) + if caseops_enabled: + datasets_dir = os.environ.get("DATA_PATH", _default_caseops_data) + tokenizer_path = os.environ.get("TOKENIZER_PATH", _default_caseops_tok) + else: + datasets_dir = os.environ.get( + "DATA_PATH", + os.path.join(data_dir, "datasets", f"fineweb10B_sp{vocab_size}"), + ) + tokenizer_path = os.environ.get( + "TOKENIZER_PATH", + os.path.join(data_dir, "tokenizers", f"fineweb_{vocab_size}_bpe.model"), + ) + train_files = os.path.join(datasets_dir, "fineweb_train_*.bin") + val_files = os.path.join(datasets_dir, "fineweb_val_*.bin") + val_bytes_files = os.path.join(datasets_dir, "fineweb_val_bytes_*.bin") + artifact_dir = os.environ.get("ARTIFACT_DIR", "") + logfile = ( + os.path.join(artifact_dir, f"{run_id}.txt") + if artifact_dir + else f"logs/{run_id}.txt" + ) + model_path = ( + os.path.join(artifact_dir, "final_model.pt") + if artifact_dir + else "final_model.pt" + ) + quantized_model_path = ( + os.path.join(artifact_dir, "final_model.int6.ptz") + if artifact_dir + else "final_model.int6.ptz" + ) + + +_logger_hparams = None + + +def set_logging_hparams(h): + global _logger_hparams + _logger_hparams = h + + +def log(msg, console=True): + if _logger_hparams is None: + print(msg) + return + if _logger_hparams.is_main_process: + if console: + print(msg) + if _logger_hparams.logfile is not None: + with open(_logger_hparams.logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + +class ValidationData: + def __init__(self, h, device): + self.sp = spm.SentencePieceProcessor(model_file=h.tokenizer_path) + if int(self.sp.vocab_size()) != h.vocab_size: + raise ValueError( + f"VOCAB_SIZE={h.vocab_size} does not match tokenizer vocab_size={int(self.sp.vocab_size())}" + ) + self.val_tokens = load_validation_tokens(h.val_files, h.eval_seq_len) + self.caseops_enabled = bool(getattr(h, "caseops_enabled", False)) + if self.caseops_enabled: + self.base_bytes_lut = None + self.has_leading_space_lut = None + self.is_boundary_token_lut = None + else: + ( + self.base_bytes_lut, + self.has_leading_space_lut, + self.is_boundary_token_lut, + ) = build_sentencepiece_luts(self.sp, h.vocab_size, device) + self.val_bytes = None + if self.caseops_enabled: + self.val_bytes = load_validation_byte_sidecar( + h.val_bytes_files, h.eval_seq_len, self.val_tokens.numel() + ) + + +def build_sentencepiece_luts(sp, vocab_size, device): + sp_vocab_size = int(sp.vocab_size()) + assert ( + sp.piece_to_id("▁") != sp.unk_id() + ), "Tokenizer must have '▁' (space) as its own token for correct BPB byte counting" + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern, seq_len): + # Filter out CaseOps byte sidecar shards which share the val_*.bin glob. + files = [ + Path(p) + for p in sorted(glob.glob(pattern)) + if "_bytes_" not in Path(p).name + ] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = (tokens.numel() - 1) // seq_len * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def load_validation_byte_sidecar(pattern, seq_len, expected_len): + """Load CaseOps per-token byte sidecar(s). Same shard layout as token shards + (256 int32 header + uint16 array). Each entry = canonical raw-text byte + budget for that token in the corresponding val shard. Returns a CPU + int16 tensor sliced to match expected_len (i.e. val_tokens length).""" + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No byte sidecar files for pattern: {pattern}") + shards = [load_data_shard(file) for file in files] + # load_data_shard returns uint16 — that's exactly what the sidecar stores. + bytes_full = torch.cat(shards).contiguous() + if bytes_full.numel() < expected_len: + raise ValueError( + f"Byte sidecar too short: {bytes_full.numel()} < val_tokens {expected_len}" + ) + return bytes_full[:expected_len].to(torch.int32) + + +def load_data_shard(file): + header_bytes = 256 * np.dtype(" 0: + pos = start + while pos < end: + seg_starts.append(pos) + pos += max_doc_len + else: + seg_starts.append(start) + boundaries = seg_starts + [total_len] + padded_len = get_next_multiple_of_n(len(boundaries), bucket_size) + cu = torch.full((padded_len,), total_len, dtype=torch.int32, device=device) + cu[: len(boundaries)] = torch.tensor(boundaries, dtype=torch.int32, device=device) + seg_ends = seg_starts[1:] + [total_len] + max_seqlen = max(end - start for start, end in zip(seg_starts, seg_ends)) + return cu, max_seqlen + +class DocumentPackingLoader: + _shard_pool = ThreadPoolExecutor(1) + + def __init__(self, h, device, cu_bucket_size=64): + self.rank = h.rank + self.world_size = h.world_size + self.device = device + self.cu_bucket_size = cu_bucket_size + self.max_seq_len = h.train_seq_len + all_files = [Path(p) for p in sorted(glob.glob(h.train_files))] + if not all_files: + raise FileNotFoundError(f"No files found for pattern: {h.train_files}") + self.files = all_files + self.file_iter = iter(self.files) + self._init_shard(load_data_shard(next(self.file_iter))) + self._next_shard = self._submit_next_shard() + self._batch_pool = ThreadPoolExecutor(1) + self._prefetch_queue = [] + + def _init_shard(self, tokens): + global BOS_ID + self.tokens = tokens + self.shard_size = tokens.numel() + if BOS_ID is None: + BOS_ID = 1 + self.bos_idx = ( + (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + ) + self.cursor = int(self.bos_idx[0]) + + def _submit_next_shard(self): + try: + path = next(self.file_iter) + return self._shard_pool.submit(load_data_shard, path) + except StopIteration: + return None + + def _advance_shard(self): + if self._next_shard is None: + self.file_iter = iter(self.files) + self._next_shard = self._shard_pool.submit( + load_data_shard, next(self.file_iter) + ) + self._init_shard(self._next_shard.result()) + self._next_shard = self._submit_next_shard() + + def _local_doc_starts(self, local_start, total_len): + lo = np.searchsorted(self.bos_idx, local_start, side="left") + hi = np.searchsorted(self.bos_idx, local_start + total_len, side="left") + return (self.bos_idx[lo:hi] - local_start).tolist() + + def _prepare_batch(self, num_tokens_local, max_seq_len): + per_rank_span = num_tokens_local + 1 + global_span = per_rank_span * self.world_size + while self.cursor + global_span > self.shard_size: + self._advance_shard() + local_start = self.cursor + self.rank * per_rank_span + buf = self.tokens[local_start : local_start + per_rank_span] + inputs = torch.empty(per_rank_span - 1, dtype=torch.int64, pin_memory=True) + targets = torch.empty(per_rank_span - 1, dtype=torch.int64, pin_memory=True) + inputs.copy_(buf[:-1]) + targets.copy_(buf[1:]) + starts = self._local_doc_starts(local_start, inputs.numel()) + cu_seqlens, max_seqlen = _build_cu_seqlens( + starts, inputs.numel(), inputs.device, max_seq_len, self.cu_bucket_size + ) + cu_seqlens = cu_seqlens.pin_memory() + self.cursor += global_span + return inputs, targets, cu_seqlens, max_seqlen + + def next_batch(self, global_tokens, grad_accum_steps): + num_tokens_local = global_tokens // (self.world_size * grad_accum_steps) + while len(self._prefetch_queue) < 2: + self._prefetch_queue.append( + self._batch_pool.submit(self._prepare_batch, num_tokens_local, self.max_seq_len)) + inputs, targets, cu_seqlens, max_seqlen = self._prefetch_queue.pop(0).result() + self._prefetch_queue.append( + self._batch_pool.submit(self._prepare_batch, num_tokens_local, self.max_seq_len)) + return ( + inputs[None].to(self.device, non_blocking=True), + targets[None].to(self.device, non_blocking=True), + cu_seqlens.to(self.device, non_blocking=True), + max_seqlen, + ) + + +class ShuffledSequenceLoader: + def __init__(self, h, device): + self.world_size = h.world_size + self.seq_len = h.train_seq_len + self.device = device + all_files = [Path(p) for p in sorted(glob.glob(h.train_files))] + if not all_files: + raise FileNotFoundError(f"No files found for pattern: {h.train_files}") + self.files = all_files[h.rank :: h.world_size] + self.rng = np.random.Generator(np.random.PCG64(h.rank)) + self.num_tokens = [_read_num_tokens(f) for f in self.files] + self.start_inds = [[] for _ in self.files] + for si in range(len(self.files)): + self._reset_shard(si) + + def _reset_shard(self, si): + max_phase = min( + self.seq_len - 1, max(0, self.num_tokens[si] - self.seq_len - 1) + ) + phase = int(self.rng.integers(max_phase + 1)) if max_phase > 0 else 0 + num_sequences = (self.num_tokens[si] - 1 - phase) // self.seq_len + sequence_order = self.rng.permutation(num_sequences) + self.start_inds[si] = (phase + sequence_order * self.seq_len).tolist() + + def next_batch(self, global_tokens, grad_accum_steps): + device_tokens = global_tokens // (self.world_size * grad_accum_steps) + device_batch_size = device_tokens // self.seq_len + remaining = np.array([len(s) for s in self.start_inds], dtype=np.float64) + x = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64) + y = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64) + for bi in range(device_batch_size): + total = remaining.sum() + if total <= 0: + for si in range(len(self.files)): + self._reset_shard(si) + remaining = np.array( + [len(s) for s in self.start_inds], dtype=np.float64 + ) + total = remaining.sum() + probs = remaining / total + si = int(self.rng.choice(len(self.files), p=probs)) + start_ind = self.start_inds[si].pop() + remaining[si] -= 1 + mm = _get_shard_memmap(self.files[si]) + window = torch.as_tensor( + np.array(mm[start_ind : start_ind + self.seq_len + 1], dtype=np.int64) + ) + x[bi] = window[:-1] + y[bi] = window[1:] + return x.to(self.device, non_blocking=True), y.to( + self.device, non_blocking=True + ) + + +class RMSNorm(nn.Module): + def __init__(self, eps=None): + super().__init__() + self.eps = eps + + def forward(self, x): + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x): + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +@triton.jit +def fused_log_softmax_dual_gather_kernel( + logits_ptr, + target_ids_ptr, + hint_ids_ptr, + log_p_y_out_ptr, + log_q_h_out_ptr, + n_rows, + n_cols, + block_cols: tl.constexpr, +): + row_idx = tl.program_id(0) + if row_idx >= n_rows: + return + target = tl.load(target_ids_ptr + row_idx) + hint = tl.load(hint_ids_ptr + row_idx) + row_offset = row_idx * n_cols + target_logit = tl.load(logits_ptr + row_offset + target).to(tl.float32) + hint_logit = tl.load(logits_ptr + row_offset + hint).to(tl.float32) + max_val = -float("inf") + for col_start in tl.range(0, n_cols, block_cols): + cols = col_start + tl.arange(0, block_cols) + mask = cols < n_cols + vals = tl.load( + logits_ptr + row_offset + cols, mask=mask, other=-float("inf") + ).to(tl.float32) + max_val = tl.maximum(max_val, tl.max(vals, axis=0)) + sum_exp = tl.zeros((), dtype=tl.float32) + for col_start in tl.range(0, n_cols, block_cols): + cols = col_start + tl.arange(0, block_cols) + mask = cols < n_cols + vals = tl.load( + logits_ptr + row_offset + cols, mask=mask, other=0.0 + ).to(tl.float32) + sum_exp += tl.sum(tl.where(mask, tl.exp(vals - max_val), 0.0), axis=0) + lse = max_val + tl.log(sum_exp) + tl.store(log_p_y_out_ptr + row_idx, target_logit - lse) + tl.store(log_q_h_out_ptr + row_idx, hint_logit - lse) + + +def fused_log_softmax_dual_gather(logits, target_ids, hint_ids): + bsz, seqlen, vocab = logits.shape + n_rows = bsz * seqlen + logits_flat = logits.reshape(n_rows, vocab).contiguous() + target_flat = target_ids.reshape(n_rows).contiguous() + hint_flat = hint_ids.reshape(n_rows).contiguous() + log_p_y_out = torch.empty(n_rows, dtype=torch.float32, device=logits.device) + log_q_h_out = torch.empty(n_rows, dtype=torch.float32, device=logits.device) + fused_log_softmax_dual_gather_kernel[(n_rows,)]( + logits_flat, + target_flat, + hint_flat, + log_p_y_out, + log_q_h_out, + n_rows, + vocab, + block_cols=1024, + num_warps=8, + ) + return log_p_y_out.reshape(bsz, seqlen), log_q_h_out.reshape(bsz, seqlen) + + +@triton.jit +def linear_leaky_relu_square_kernel( + a_desc, + b_desc, + c_desc, + aux_desc, + M, + N, + K, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + NUM_SMS: tl.constexpr, + FORWARD: tl.constexpr, +): + dtype = tl.bfloat16 + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + tile_id_c = start_pid - NUM_SMS + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for ki in range(k_tiles): + offs_k = ki * BLOCK_SIZE_K + a = a_desc.load([offs_am, offs_k]) + b = b_desc.load([offs_bn, offs_k]) + accumulator = tl.dot(a, b.T, accumulator) + tile_id_c += NUM_SMS + offs_am_c = offs_am + offs_bn_c = offs_bn + acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) + acc = tl.permute(acc, (0, 2, 1)) + acc0, acc1 = tl.split(acc) + c0 = acc0.to(dtype) + c1 = acc1.to(dtype) + if not FORWARD: + pre0 = aux_desc.load([offs_am_c, offs_bn_c]) + pre1 = aux_desc.load([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2]) + c0 = c0 * tl.where(pre0 > 0, 2.0 * pre0, 0.3 * pre0) + c1 = c1 * tl.where(pre1 > 0, 2.0 * pre1, 0.3 * pre1) + c_desc.store([offs_am_c, offs_bn_c], c0) + c_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], c1) + if FORWARD: + aux0 = tl.where(c0 > 0, c0, 0.3 * c0) + aux1 = tl.where(c1 > 0, c1, 0.3 * c1) + aux_desc.store([offs_am_c, offs_bn_c], aux0 * aux0) + aux_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], aux1 * aux1) + + +def linear_leaky_relu_square(a, b, aux=None): + M, K = a.shape + N, K2 = b.shape + assert K == K2 + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + forward = aux is None + if aux is None: + aux = torch.empty((M, N), device=a.device, dtype=a.dtype) + num_sms = torch.cuda.get_device_properties(a.device).multi_processor_count + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 256, 128, 64 + num_stages = 4 if forward else 3 + a_desc = TensorDescriptor.from_tensor(a, [BLOCK_SIZE_M, BLOCK_SIZE_K]) + b_desc = TensorDescriptor.from_tensor(b, [BLOCK_SIZE_N, BLOCK_SIZE_K]) + c_desc = TensorDescriptor.from_tensor(c, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + aux_desc = TensorDescriptor.from_tensor(aux, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + grid = lambda _meta: ( + min(num_sms, triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N)), + ) + linear_leaky_relu_square_kernel[grid]( + a_desc, + b_desc, + c_desc, + aux_desc, + M, + N, + K, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + NUM_SMS=num_sms, + FORWARD=forward, + num_stages=num_stages, + num_warps=8, + ) + if forward: + return c, aux + return c + + +class FusedLinearLeakyReLUSquareFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, w1, w2): + x_flat = x.reshape(-1, x.shape[-1]) + pre, post = linear_leaky_relu_square(x_flat, w1) + out = F.linear(post, w2) + ctx.save_for_backward(x, w1, w2, pre, post) + return out.view(*x.shape[:-1], out.shape[-1]) + + @staticmethod + def backward(ctx, grad_output): + x, w1, w2, pre, post = ctx.saved_tensors + x_flat = x.reshape(-1, x.shape[-1]) + grad_output_flat = grad_output.reshape(-1, grad_output.shape[-1]) + dw2 = grad_output_flat.T @ post + dpre = linear_leaky_relu_square(grad_output_flat, w2.T.contiguous(), aux=pre) + dw1 = dpre.T @ x_flat + dx = dpre @ w1 + return dx.view_as(x), dw1, dw2 + + +FusedLeakyReLUSquareMLP = FusedLinearLeakyReLUSquareFunction.apply + + +class Rotary(nn.Module): + def __init__(self, dim, base=1e4, train_seq_len=1024, rope_dims=0, yarn=True): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.yarn = yarn + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / base ** ( + torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + + def forward(self, seq_len, device, dtype): + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached < seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if self.yarn and seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * scale ** (rd / (rd - 2)) + inv_freq = 1.0 / new_base ** ( + torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd + ) + else: + inv_freq = self.inv_freq.float().to(device) + t = torch.arange(seq_len, device=device, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached[:, :seq_len].to(dtype=dtype), self._sin_cached[:, :seq_len].to(dtype=dtype) + + +def apply_rotary_emb(x, cos, sin, rope_dims=0): + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * -sin + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * -sin + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len, yarn=True, + attn_out_gate=False, attn_out_gate_src="proj", gate_window=12, + gated_attn=False, gated_attn_init_std=0.01, + sparse_attn_gate=False, sparse_attn_gate_init_std=0.0, sparse_attn_gate_scale=1.0, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + if int(attn_out_gate) + int(gated_attn) + int(sparse_attn_gate) > 1: + raise ValueError( + "attn_out_gate, gated_attn, and sparse_attn_gate are mutually exclusive" + ) + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + self.q_gain = nn.Parameter( + torch.full((num_heads,), qk_gain_init, dtype=torch.float32) + ) + self.rope_dims = 0 + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=train_seq_len, yarn=yarn) + self.use_xsa = False + # AttnOutGate (PR #1667 MarioPaerle): per-head multiplicative gate on attention + # output. CastedLinear so restore_fp32_params casts back to fp32 for GPTQ. + # _zero_init -> 2*sigmoid(0)=1 -> transparent at init. + self.attn_out_gate = attn_out_gate + self.attn_out_gate_src = attn_out_gate_src + self.gate_window = gate_window + if attn_out_gate: + self.attn_gate_proj = CastedLinear(gate_window, num_heads, bias=False) + self.attn_gate_proj._zero_init = True + # Gated Attention (arXiv:2505.06708, Qwen, NeurIPS 2025). Per-head sigmoid + # gate on SDPA output, BEFORE out_proj. Gate projection W_g: (num_heads, dim). + # Name "attn_gate_w" contains "attn_gate" substring so it matches + # CONTROL_TENSOR_NAME_PATTERNS and routes to the scalar AdamW group. + # fp32 Parameter -> restore_fp32_params path covers it via the ndim<2 OR + # name-pattern check (name matches "attn_gate"). Cast to x.dtype on use. + self.gated_attn = gated_attn + if gated_attn: + W = torch.empty(num_heads, dim, dtype=torch.float32) + nn.init.normal_(W, mean=0.0, std=gated_attn_init_std) + self.attn_gate_w = nn.Parameter(W) + # Sparse attention head-output gate (modded-nanogpt style). Keeps dense SDPA + # and only narrows the gate input to the first gate_window residual dims. + # W_g: (num_heads, gate_window). y_{t,h} <- sigmoid(scale * W_g_h @ x_t[:gate_window]) * y_{t,h}. + # Shares attn_gate_w name with dense GatedAttn so the quant routing + # (CONTROL_TENSOR_NAME_PATTERNS / attn_gate_w int8 passthrough) is unchanged. + self.sparse_attn_gate = sparse_attn_gate + self.sparse_attn_gate_scale = sparse_attn_gate_scale + if sparse_attn_gate: + W = torch.empty(num_heads, gate_window, dtype=torch.float32) + if sparse_attn_gate_init_std > 0: + nn.init.normal_(W, mean=0.0, std=sparse_attn_gate_init_std) + else: + nn.init.zeros_(W) + self.attn_gate_w = nn.Parameter(W) + + def _xsa_efficient(self, y, v): + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x, q_w, k_w, v_w, out_w, cu_seqlens=None, max_seqlen=0): + bsz, seqlen, dim = x.shape + # q_raw kept around as a tap point for attn_out_gate_src='q' (post-projection, + # pre-reshape, pre-RoPE). + q_raw = F.linear(x, q_w.to(x.dtype)) + q = q_raw.reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = F.linear(x, k_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = F.linear(x, v_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + if cu_seqlens is not None: + y = flash_attn_varlen_func( + q[0], + k[0], + v[0], + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + causal=True, + window_size=(-1, -1), + )[None] + else: + y = flash_attn_3_func(q, k, v, causal=True) + if self.use_xsa: + y = self._xsa_efficient(y, v) + # AttnOutGate inlined (PR #1667). Inline + .contiguous() barrier so torch.compile + # fullgraph=True is happy (this avoids the @torch.compiler.disable trap that + # crashed gates v3). Per-head gate on (B,T,H,D) tensor: g shape [B,T,H], broadcast + # over D via [..., None]. zero-init weight -> 2*sigmoid(0)=1 -> transparent. + if self.attn_out_gate: + gate_src = q_raw if self.attn_out_gate_src == "q" else x + gate_in = gate_src[..., : self.gate_window].contiguous() + g = 2.0 * torch.sigmoid(self.attn_gate_proj(gate_in)) + y = y * g[..., None] + # Gated Attention (arXiv:2505.06708 G1). Inline + .contiguous() barrier so + # torch.compile fullgraph=True is happy. Per-head gate on (B,T,H,D): g shape + # [B,T,H], broadcast over D via [..., None]. Paper: g = sigmoid(x @ W_g.T) + # where W_g: (H, dim). .to(x.dtype) on fp32 param before broadcast with bf16. + if self.gated_attn: + x_c = x.contiguous() + g = torch.sigmoid(F.linear(x_c, self.attn_gate_w.to(x.dtype))) + y = y * g[..., None] + # Sparse head-output gate: narrower (gate_window) input, same shape g as GatedAttn. + if self.sparse_attn_gate: + gate_in = x[..., : self.gate_window].contiguous() + g = torch.sigmoid( + self.sparse_attn_gate_scale + * F.linear(gate_in, self.attn_gate_w.to(x.dtype)) + ) + y = y * g[..., None] + y = y.reshape(bsz, seqlen, dim) + self._last_proj_input = y.detach() if getattr(self, "_calib", False) else None + return F.linear(y, out_w.to(x.dtype)) + + +class MLP(nn.Module): + def __init__(self, dim, mlp_mult): + super().__init__() + self.use_fused = True + + def forward(self, x, up_w, down_w): + if self.training and self.use_fused: + return FusedLeakyReLUSquareMLP(x, up_w.to(x.dtype), down_w.to(x.dtype)) + hidden = F.leaky_relu(F.linear(x, up_w.to(x.dtype)), negative_slope=0.3).square() + self._last_down_input = hidden.detach() if getattr(self, "_calib", False) else None + return F.linear(hidden, down_w.to(x.dtype)) + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + train_seq_len, + layer_idx=0, + ln_scale=False, + yarn=True, + attn_out_gate=False, + attn_out_gate_src="proj", + gate_window=12, + gated_attn=False, + gated_attn_init_std=0.01, + sparse_attn_gate=False, + sparse_attn_gate_init_std=0.0, + sparse_attn_gate_scale=1.0, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention( + dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len, yarn=yarn, + attn_out_gate=attn_out_gate, attn_out_gate_src=attn_out_gate_src, gate_window=gate_window, + gated_attn=gated_attn, gated_attn_init_std=gated_attn_init_std, + sparse_attn_gate=sparse_attn_gate, + sparse_attn_gate_init_std=sparse_attn_gate_init_std, + sparse_attn_gate_scale=sparse_attn_gate_scale, + ) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter( + torch.stack((torch.ones(dim), torch.zeros(dim))).float() + ) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + + def forward(self, x, x0, q_w, k_w, v_w, out_w, up_w, down_w, cu_seqlens=None, max_seqlen=0): + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn( + self.attn_norm(x_in) * self.ln_scale_factor, + q_w, k_w, v_w, out_w, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[ + None, None, : + ] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor, up_w, down_w) + return x_out + +class GPT(nn.Module): + def __init__(self, h): + super().__init__() + if h.logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {h.logit_softcap}") + self.tie_embeddings = h.tie_embeddings + self.tied_embed_init_std = h.tied_embed_init_std + self.logit_softcap = h.logit_softcap + self.fused_ce_enabled = bool(h.fused_ce_enabled) + self.tok_emb = nn.Embedding(h.vocab_size, h.model_dim) + self.num_layers = h.num_layers + head_dim = h.model_dim // h.num_heads + kv_dim = h.num_kv_heads * head_dim + hidden_dim = int(h.mlp_mult * h.model_dim) + self.qo_bank = nn.Parameter(torch.empty(2 * h.num_layers, h.model_dim, h.model_dim)) + self.kv_bank = nn.Parameter(torch.empty(2 * h.num_layers, kv_dim, h.model_dim)) + self.mlp_up_bank = nn.Parameter(torch.empty(h.num_layers, hidden_dim, h.model_dim)) + self.mlp_down_bank = nn.Parameter(torch.empty(h.num_layers, h.model_dim, hidden_dim)) + self.num_encoder_layers = h.num_layers // 2 + self.num_decoder_layers = h.num_layers - self.num_encoder_layers + self.blocks = nn.ModuleList( + [ + Block( + h.model_dim, + h.num_heads, + h.num_kv_heads, + h.mlp_mult, + h.rope_base, + h.qk_gain_init, + h.train_seq_len, + layer_idx=i, + ln_scale=h.ln_scale, + yarn=h.rope_yarn, + attn_out_gate=h.attn_out_gate_enabled, + attn_out_gate_src=h.attn_out_gate_src, + gate_window=h.gate_window, + gated_attn=h.gated_attn_enabled, + gated_attn_init_std=h.gated_attn_init_std, + sparse_attn_gate=h.sparse_attn_gate_enabled, + sparse_attn_gate_init_std=h.sparse_attn_gate_init_std, + sparse_attn_gate_scale=h.sparse_attn_gate_scale, + ) + for i in range(h.num_layers) + ] + ) + if h.rope_dims > 0: + head_dim = h.model_dim // h.num_heads + for block in self.blocks: + block.attn.rope_dims = h.rope_dims + block.attn.rotary = Rotary( + head_dim, + base=h.rope_base, + train_seq_len=h.train_seq_len, + rope_dims=h.rope_dims, + yarn=h.rope_yarn, + ) + self.final_norm = RMSNorm() + self.lm_head = ( + None + if h.tie_embeddings + else CastedLinear(h.model_dim, h.vocab_size, bias=False) + ) + if self.lm_head is not None: + self.lm_head._zero_init = True + if h.xsa_last_n > 0: + for i in range(max(0, h.num_layers - h.xsa_last_n), h.num_layers): + self.blocks[i].attn.use_xsa = True + self.looping_active = False + if h.num_loops > 0: + loop_seg = list(range(h.loop_start, h.loop_end + 1)) + all_indices = list(range(h.loop_start)) + for _ in range(h.num_loops + 1): + all_indices.extend(loop_seg) + all_indices.extend(range(h.loop_end + 1, h.num_layers)) + num_enc = len(all_indices) // 2 + self.encoder_indices = all_indices[:num_enc] + self.decoder_indices = all_indices[num_enc:] + else: + self.encoder_indices = list(range(self.num_encoder_layers)) + self.decoder_indices = list(range(self.num_encoder_layers, h.num_layers)) + self.num_skip_weights = min( + len(self.encoder_indices), len(self.decoder_indices) + ) + self.skip_weights = nn.Parameter( + torch.ones(self.num_skip_weights, h.model_dim, dtype=torch.float32) + ) + self.skip_gates = ( + nn.Parameter( + torch.zeros(self.num_skip_weights, h.model_dim, dtype=torch.float32) + ) + if h.skip_gates_enabled + else None + ) + self.parallel_start_layer = h.parallel_start_layer + self.parallel_final_lane = h.parallel_final_lane.lower() + self.parallel_post_lambdas = nn.Parameter( + torch.ones(h.num_layers, 2, 2, dtype=torch.float32) + ) + self.parallel_resid_lambdas = nn.Parameter( + torch.full((h.num_layers, 2), 1.1, dtype=torch.float32) + ) + # SmearGate (PR #1667 / modded-nanogpt @classiclarryd): + # x_t <- x_t + lam * sigmoid(W * x_t[:gate_window]) * x_{t-1}. + # Per-token forward-1 smear of the embedding lane. W zero-init + lam=0 -> + # transparent at init. Uses CastedLinear so restore_fp32_params handles dtype. + self.smear_gate_enabled = h.smear_gate_enabled + if self.smear_gate_enabled: + self.smear_window = h.gate_window + self.smear_gate = CastedLinear(self.smear_window, 1, bias=False) + self.smear_gate._zero_init = True + self.smear_lambda = nn.Parameter(torch.zeros(1, dtype=torch.float32)) + # V19: Asymmetric Logit Rescale (PR #1923 jorge-asenjo). + # Two learnable softcap scales applied on the EVAL path (forward_logits + + # forward_ttt). Init to logit_softcap so the layer is identity at step 0. + # Train path keeps the single fused softcap to preserve PR #1855 numerics. + self.asym_logit_enabled = bool(int(os.environ.get("ASYM_LOGIT_RESCALE", "1"))) + if self.asym_logit_enabled: + self.softcap_pos = nn.Parameter(torch.tensor(float(h.logit_softcap), dtype=torch.float32)) + self.softcap_neg = nn.Parameter(torch.tensor(float(h.logit_softcap), dtype=torch.float32)) + self.jepa_horizon = h.jepa_lite_horizon + self.jepa_stride = h.jepa_lite_stride + self.register_buffer( + "jepa_weight_current", + torch.zeros((), dtype=torch.float32), + persistent=False, + ) + self.jepa_predictor = ( + nn.Sequential( + CastedLinear(h.model_dim, h.jepa_lite_hidden, bias=False), + nn.LeakyReLU(0.5), + CastedLinear(h.jepa_lite_hidden, h.model_dim, bias=False), + ) + if h.jepa_lite_enabled + else None + ) + self.bijepa_enabled = bool(h.bijepa_enabled) + self.bijepa_backward_weight = float(h.bijepa_backward_weight) + self.jepa_bwd_predictor = ( + nn.Sequential( + CastedLinear(h.model_dim, h.jepa_lite_hidden, bias=False), + nn.LeakyReLU(0.5), + CastedLinear(h.jepa_lite_hidden, h.model_dim, bias=False), + ) + if h.jepa_lite_enabled and h.bijepa_enabled + else None + ) + self._init_weights() + + def _init_weights(self): + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + n = self.num_layers + proj_scale = 1.0 / math.sqrt(2 * n) + for i in range(n): + nn.init.orthogonal_(self.qo_bank.data[i], gain=1.0) + nn.init.zeros_(self.qo_bank.data[n + i]) + self.qo_bank.data[n + i].mul_(proj_scale) + nn.init.orthogonal_(self.kv_bank.data[i], gain=1.0) + nn.init.orthogonal_(self.kv_bank.data[n + i], gain=1.0) + for i in range(n): + nn.init.orthogonal_(self.mlp_up_bank.data[i], gain=1.0) + nn.init.zeros_(self.mlp_down_bank.data[i]) + self.mlp_down_bank.data[i].mul_(proj_scale) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif ( + module.weight.ndim == 2 + and module.weight.shape[0] >= 64 + and module.weight.shape[1] >= 64 + ): + nn.init.orthogonal_(module.weight, gain=1.0) + + def _bank_weights(self, i): + n = self.num_layers + return ( + self.qo_bank[i], + self.kv_bank[i], + self.kv_bank[n + i], + self.qo_bank[n + i], + self.mlp_up_bank[i], + self.mlp_down_bank[i], + ) + + def _parallel_block( + self, block_idx, lane0, lane1, x0, + q_w, k_w, v_w, out_w, up_w, down_w, + cu_seqlens=None, max_seqlen=0, + ): + block = self.blocks[block_idx] + mix = block.resid_mix.to(dtype=lane0.dtype) + attn_read = mix[0][None, None, :] * lane0 + mix[1][None, None, :] * x0 + attn_out = block.attn( + block.attn_norm(attn_read) * block.ln_scale_factor, + q_w, k_w, v_w, out_w, + cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, + ) + attn_out = block.attn_scale.to(dtype=attn_out.dtype)[None, None, :] * attn_out + mlp_read = lane1 + mlp_out = block.mlp_scale.to(dtype=lane1.dtype)[None, None, :] * block.mlp( + block.mlp_norm(mlp_read) * block.ln_scale_factor, up_w, down_w + ) + attn_resid = self.parallel_resid_lambdas[block_idx, 0].to(dtype=lane0.dtype) + attn_post = self.parallel_post_lambdas[block_idx, 0].to(dtype=lane0.dtype) + mlp_resid = self.parallel_resid_lambdas[block_idx, 1].to(dtype=lane0.dtype) + mlp_post = self.parallel_post_lambdas[block_idx, 1].to(dtype=lane0.dtype) + lane0 = attn_resid * lane0 + attn_post[0] * attn_out + mlp_post[0] * mlp_out + lane1 = mlp_resid * lane1 + attn_post[1] * attn_out + mlp_post[1] * mlp_out + return lane0, lane1 + + def _final_parallel_hidden(self, lane0, lane1): + if self.parallel_final_lane == "mlp": + return lane1 + if self.parallel_final_lane == "attn": + return lane0 + return 0.5 * (lane0 + lane1) + + def _forward_hidden(self, input_ids, cu_seqlens=None, max_seqlen=0): + """Run the encoder/decoder stack to the final RMSNorm; returns pre-projection hidden. + Shared by eval (softcap+projection via forward_logits) and train (fused CE path).""" + x = self.tok_emb(input_ids) + # SmearGate (PR #1667). lam=0 + W=0 -> identity at init. + # Cross-doc leak fix: zero the prev-token smear at any position whose current token + # is BOS, so the BOS embedding starting doc N+1 in a packed stream is not + # contaminated by doc N's last token (audited issue on PR#1797 base). + if self.smear_gate_enabled: + sl = self.smear_lambda.to(dtype=x.dtype) + gate_in = x[:, 1:, : self.smear_window].contiguous() + g = sl * torch.sigmoid(self.smear_gate(gate_in)) + not_bos = (input_ids[:, 1:] != BOS_ID).to(x.dtype).unsqueeze(-1) + x = torch.cat([x[:, :1], x[:, 1:] + g * x[:, :-1] * not_bos], dim=1) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips = [] + enc_iter = ( + self.encoder_indices + if self.looping_active + else range(self.num_encoder_layers) + ) + dec_iter = ( + self.decoder_indices + if self.looping_active + else range( + self.num_encoder_layers, + self.num_encoder_layers + self.num_decoder_layers, + ) + ) + for i in enc_iter: + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + x = self.blocks[i](x, x0, q_w, k_w, v_w, out_w, up_w, down_w, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + skips.append(x) + psl = self.parallel_start_layer + lane0 = None + lane1 = None + for skip_idx, i in enumerate(dec_iter): + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + if i >= psl and psl > 0: + if lane0 is None: + lane0 = x + lane1 = x + if skip_idx < self.num_skip_weights and skips: + skip = skips.pop() + w = self.skip_weights[skip_idx].to(dtype=lane0.dtype)[None, None, :] + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=lane0.dtype))[None, None, :] + lane0 = torch.lerp(w * skip, lane0, g) + else: + lane0 = lane0 + w * skip + lane0, lane1 = self._parallel_block( + i, lane0, lane1, x0, q_w, k_w, v_w, out_w, up_w, down_w, + cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, + ) + else: + if skip_idx < self.num_skip_weights and skips: + scaled_skip = ( + self.skip_weights[skip_idx].to(dtype=x.dtype)[None, None, :] + * skips.pop() + ) + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=x.dtype))[None, None, :] + x = torch.lerp(scaled_skip, x, g) + else: + x = x + scaled_skip + x = self.blocks[i](x, x0, q_w, k_w, v_w, out_w, up_w, down_w, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + if lane0 is not None: + x = self._final_parallel_hidden(lane0, lane1) + x = self.final_norm(x) + return x + + def _project_logits(self, hidden): + if self.tie_embeddings: + return F.linear(hidden, self.tok_emb.weight) + return self.lm_head(hidden) + + def _apply_asym_softcap(self, logits): + # V19: Asymmetric softcap (PR #1923). Splits the logit_softcap scalar into + # learnable positive/negative branches. Score-first preserved: still a + # bounded, normalized post-projection nonlinearity feeding a standard + # softmax over the full vocab. + sp = self.softcap_pos.to(logits.dtype) + sn = self.softcap_neg.to(logits.dtype) + return torch.where(logits > 0, sp * torch.tanh(logits / sp), sn * torch.tanh(logits / sn)) + + def forward_logits(self, input_ids, cu_seqlens=None, max_seqlen=0): + hidden = self._forward_hidden(input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + logits_proj = self._project_logits(hidden) + if self.asym_logit_enabled: + return self._apply_asym_softcap(logits_proj) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + def jepa_aux_loss(self, features): + if self.jepa_predictor is None: + return features.new_zeros(()) + horizon = max(1, int(self.jepa_horizon)) + stride = max(1, int(self.jepa_stride)) + if features.size(1) <= horizon: + return features.new_zeros(()) + src = features[:, :-horizon:stride, :] + tgt = features[:, horizon::stride, :] + n = min(src.size(1), tgt.size(1)) + if n <= 0: + return features.new_zeros(()) + src = F.normalize(src[:, :n, :].float(), dim=-1) + tgt = F.normalize(tgt[:, :n, :].detach().float(), dim=-1) + pred = F.normalize(self.jepa_predictor(src), dim=-1) + loss = (1.0 - (pred * tgt).sum(dim=-1)).mean() + if self.jepa_bwd_predictor is not None: + bwd_src = F.normalize(features[:, horizon::stride, :][:, :n, :].float(), dim=-1) + bwd_tgt = F.normalize( + features[:, :-horizon:stride, :][:, :n, :].detach().float(), + dim=-1, + ) + bwd_pred = F.normalize(self.jepa_bwd_predictor(bwd_src), dim=-1) + bwd_loss = (1.0 - (bwd_pred * bwd_tgt).sum(dim=-1)).mean() + loss = loss + self.bijepa_backward_weight * bwd_loss + return loss * self.jepa_weight_current.to( + device=features.device, + dtype=features.dtype, + ) + + def forward(self, input_ids, target_ids, cu_seqlens=None, max_seqlen=0): + hidden = self._forward_hidden(input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + logits_proj = self._project_logits(hidden) + flat_targets = target_ids.reshape(-1) + # Fused softcapped-CE kernel (training path only). Applies softcap inside the + # Triton kernel; takes pre-softcap logits_proj. Non-fused path matches stock + # PR-1736 numerics exactly (softcap in fp32, then F.cross_entropy on fp32). + if self.fused_ce_enabled: + ce_loss = softcapped_cross_entropy( + logits_proj.reshape(-1, logits_proj.size(-1)), + flat_targets, + self.logit_softcap, + reduction="mean", + ) + else: + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + ce_loss = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + flat_targets, + reduction="mean", + ) + if self.jepa_predictor is not None: + ce_loss = ce_loss + self.jepa_aux_loss(hidden) + return ce_loss + + def forward_with_hidden(self, input_ids, target_ids, cu_seqlens=None, max_seqlen=0): + hidden = self._forward_hidden( + input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ) + logits_proj = self._project_logits(hidden) + flat_targets = target_ids.reshape(-1) + if self.fused_ce_enabled: + ce_loss = softcapped_cross_entropy( + logits_proj.reshape(-1, logits_proj.size(-1)), + flat_targets, + self.logit_softcap, + reduction="mean", + ) + else: + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + ce_loss = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + flat_targets, + reduction="mean", + ) + if self.jepa_predictor is not None: + ce_loss = ce_loss + self.jepa_aux_loss(hidden) + return ce_loss, hidden + + def forward_ttt(self, input_ids, target_ids, lora, hint_ids=None): + x = self.tok_emb(input_ids) + # SmearGate on the TTT path — same inline compute as forward_logits. + # Cross-doc leak fix: see _forward_hidden comment. + if self.smear_gate_enabled: + sl = self.smear_lambda.to(dtype=x.dtype) + gate_in = x[:, 1:, : self.smear_window].contiguous() + g = sl * torch.sigmoid(self.smear_gate(gate_in)) + not_bos = (input_ids[:, 1:] != BOS_ID).to(x.dtype).unsqueeze(-1) + x = torch.cat([x[:, :1], x[:, 1:] + g * x[:, :-1] * not_bos], dim=1) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips = [] + enc_iter = ( + self.encoder_indices + if self.looping_active + else list(range(self.num_encoder_layers)) + ) + dec_iter = ( + self.decoder_indices + if self.looping_active + else list( + range( + self.num_encoder_layers, + self.num_encoder_layers + self.num_decoder_layers, + ) + ) + ) + slot = 0 + for i in enc_iter: + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + x = self._block_with_lora(self.blocks[i], x, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w) + slot += 1 + skips.append(x) + psl = self.parallel_start_layer + lane0 = None + lane1 = None + for skip_idx, i in enumerate(dec_iter): + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + if i >= psl and psl > 0: + if lane0 is None: + lane0 = x + lane1 = x + if skip_idx < self.num_skip_weights and skips: + skip = skips.pop() + w = self.skip_weights[skip_idx].to(dtype=lane0.dtype)[None, None, :] + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=lane0.dtype))[None, None, :] + lane0 = torch.lerp(w * skip, lane0, g) + else: + lane0 = lane0 + w * skip + lane0, lane1 = self._parallel_block_with_lora( + i, lane0, lane1, x0, lora, slot, + q_w, k_w, v_w, out_w, up_w, down_w, + ) + else: + if skip_idx < self.num_skip_weights and skips: + scaled_skip = ( + self.skip_weights[skip_idx].to(dtype=x.dtype)[None, None, :] + * skips.pop() + ) + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=x.dtype))[None, None, :] + x = torch.lerp(scaled_skip, x, g) + else: + x = x + scaled_skip + x = self._block_with_lora(self.blocks[i], x, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w) + slot += 1 + if lane0 is not None: + x = self._final_parallel_hidden(lane0, lane1) + x = self.final_norm(x) + if self.tie_embeddings: + logits = F.linear(x, self.tok_emb.weight) + else: + logits = self.lm_head(x) + logits = logits + lora.lm_head_lora(x) + # V19: same asymmetric softcap on the TTT eval path. + if self.asym_logit_enabled: + logits = self._apply_asym_softcap(logits) + else: + logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) + bsz, sl, V = logits.shape + if hint_ids is None: + return F.cross_entropy( + logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none" + ).reshape(bsz, sl) + if not logits.requires_grad: + log_p_y, log_q_h = fused_log_softmax_dual_gather( + logits, target_ids, hint_ids.clamp(min=0) + ) + return -log_p_y, log_q_h + ls = F.log_softmax(logits.float(), dim=-1) + log_p_y = ls.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1) + log_q_h = ls.gather(-1, hint_ids.clamp(min=0).unsqueeze(-1)).squeeze(-1) + return -log_p_y, log_q_h + + def _block_with_lora(self, block, x, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w): + mix = block.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + n = block.attn_norm(x_in) * block.ln_scale_factor + attn = block.attn + bsz, seqlen, dim = n.shape + # Keep raw Q for AttnOutGate src='q' (matches forward path semantics). + q_raw = F.linear(n, q_w.to(n.dtype)) + if lora.q_loras is not None: + q_raw = q_raw + lora.q_loras[slot](n) + q = q_raw.reshape(bsz, seqlen, attn.num_heads, attn.head_dim) + k = F.linear(n, k_w.to(n.dtype)) + if lora.k_loras is not None: + k = k + lora.k_loras[slot](n) + k = k.reshape(bsz, seqlen, attn.num_kv_heads, attn.head_dim) + v = F.linear(n, v_w.to(n.dtype)) + if lora.v_loras is not None: + v = v + lora.v_loras[slot](n) + v = v.reshape(bsz, seqlen, attn.num_kv_heads, attn.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = attn.rotary(seqlen, n.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, attn.rope_dims) + k = apply_rotary_emb(k, cos, sin, attn.rope_dims) + q = q * attn.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if attn.use_xsa: + y = attn._xsa_efficient(y, v) + # AttnOutGate (TTT path) — inline + .contiguous() barrier, same as the eval path. + if attn.attn_out_gate: + gate_src = q_raw if attn.attn_out_gate_src == "q" else n + gate_in = gate_src[..., : attn.gate_window].contiguous() + g = 2.0 * torch.sigmoid(attn.attn_gate_proj(gate_in)) + y = y * g[..., None] + # Gated Attention (TTT path). Gate input is n (post-norm block input), same + # as eval path. .to(n.dtype) on fp32 param before bf16 broadcast. + if attn.gated_attn: + n_c = n.contiguous() + g = torch.sigmoid(F.linear(n_c, attn.attn_gate_w.to(n.dtype))) + y = y * g[..., None] + # Sparse attention head-output gate (TTT path) — must match the eval path in + # forward() exactly, else training (which applied the gate) and TTT eval (which + # skipped it) produce mismatched representations and catastrophic BPB regression. + if attn.sparse_attn_gate: + gate_in = n[..., : attn.gate_window].contiguous() + g = torch.sigmoid( + attn.sparse_attn_gate_scale + * F.linear(gate_in, attn.attn_gate_w.to(n.dtype)) + ) + y = y * g[..., None] + y = y.reshape(bsz, seqlen, dim) + attn_out = F.linear(y, out_w.to(n.dtype)) + if lora.o_loras is not None: + attn_out = attn_out + lora.o_loras[slot](n) + x_out = x_in + block.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + mlp_n = block.mlp_norm(x_out) * block.ln_scale_factor + mlp_out = block.mlp(mlp_n, up_w, down_w) + if lora.mlp_loras is not None: + mlp_out = mlp_out + lora.mlp_loras[slot](mlp_n) + x_out = x_out + block.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * mlp_out + return x_out + + def _parallel_block_with_lora( + self, block_idx, lane0, lane1, x0, lora, slot, + q_w, k_w, v_w, out_w, up_w, down_w, + ): + block = self.blocks[block_idx] + mix = block.resid_mix.to(dtype=lane0.dtype) + attn_read = mix[0][None, None, :] * lane0 + mix[1][None, None, :] * x0 + n = block.attn_norm(attn_read) * block.ln_scale_factor + attn = block.attn + bsz, seqlen, dim = n.shape + q_raw = F.linear(n, q_w.to(n.dtype)) + if lora.q_loras is not None: + q_raw = q_raw + lora.q_loras[slot](n) + q = q_raw.reshape(bsz, seqlen, attn.num_heads, attn.head_dim) + k = F.linear(n, k_w.to(n.dtype)) + if lora.k_loras is not None: + k = k + lora.k_loras[slot](n) + k = k.reshape(bsz, seqlen, attn.num_kv_heads, attn.head_dim) + v = F.linear(n, v_w.to(n.dtype)) + if lora.v_loras is not None: + v = v + lora.v_loras[slot](n) + v = v.reshape(bsz, seqlen, attn.num_kv_heads, attn.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = attn.rotary(seqlen, n.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, attn.rope_dims) + k = apply_rotary_emb(k, cos, sin, attn.rope_dims) + q = q * attn.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if attn.use_xsa: + y = attn._xsa_efficient(y, v) + # AttnOutGate (TTT parallel path) — inline + .contiguous() barrier. + if attn.attn_out_gate: + gate_src = q_raw if attn.attn_out_gate_src == "q" else n + gate_in = gate_src[..., : attn.gate_window].contiguous() + g = 2.0 * torch.sigmoid(attn.attn_gate_proj(gate_in)) + y = y * g[..., None] + # Gated Attention (TTT parallel path). Gate input is n (post-norm block input). + if attn.gated_attn: + n_c = n.contiguous() + g = torch.sigmoid(F.linear(n_c, attn.attn_gate_w.to(n.dtype))) + y = y * g[..., None] + # Sparse attention head-output gate (TTT parallel path) — must match the + # eval path in forward() to keep train/eval semantics in sync. + if attn.sparse_attn_gate: + gate_in = n[..., : attn.gate_window].contiguous() + g = torch.sigmoid( + attn.sparse_attn_gate_scale + * F.linear(gate_in, attn.attn_gate_w.to(n.dtype)) + ) + y = y * g[..., None] + y = y.reshape(bsz, seqlen, dim) + attn_out = F.linear(y, out_w.to(n.dtype)) + if lora.o_loras is not None: + attn_out = attn_out + lora.o_loras[slot](n) + attn_out = block.attn_scale.to(dtype=attn_out.dtype)[None, None, :] * attn_out + mlp_read = lane1 + mlp_n = block.mlp_norm(mlp_read) * block.ln_scale_factor + mlp_out = block.mlp(mlp_n, up_w, down_w) + if lora.mlp_loras is not None: + mlp_out = mlp_out + lora.mlp_loras[slot](mlp_n) + mlp_out = block.mlp_scale.to(dtype=lane1.dtype)[None, None, :] * mlp_out + attn_resid = self.parallel_resid_lambdas[block_idx, 0].to(dtype=lane0.dtype) + attn_post = self.parallel_post_lambdas[block_idx, 0].to(dtype=lane0.dtype) + mlp_resid = self.parallel_resid_lambdas[block_idx, 1].to(dtype=lane0.dtype) + mlp_post = self.parallel_post_lambdas[block_idx, 1].to(dtype=lane0.dtype) + lane0 = attn_resid * lane0 + attn_post[0] * attn_out + mlp_post[0] * mlp_out + lane1 = mlp_resid * lane1 + attn_post[1] * attn_out + mlp_post[1] * mlp_out + return lane0, lane1 + + +class BatchedLinearLoRA(nn.Module): + # PR-1767: rank-scaled output (alpha/rank), like standard LoRA. Decouples + # effective magnitude from rank so changing rank does not change LR scale. + _ALPHA = float(os.environ.get("TTT_LORA_ALPHA", "144")) + # PR-1767: optionally keep A warm across per-doc resets (only B is zeroed). + # Accumulates useful feature directions across documents within a TTT phase. + _WARM_START_A = bool(int(os.environ.get("TTT_WARM_START_A", "1"))) + + def __init__(self, bsz, in_features, out_features, rank): + super().__init__() + self._bound = 1.0 / math.sqrt(in_features) + self._scale = self._ALPHA / rank + self.A = nn.Parameter( + torch.empty(bsz, rank, in_features).uniform_(-self._bound, self._bound) + ) + self.B = nn.Parameter(torch.zeros(bsz, out_features, rank)) + + def reset(self): + with torch.no_grad(): + if not self._WARM_START_A: + self.A.uniform_(-self._bound, self._bound) + self.B.zero_() + + def forward(self, x): + return ((x @ self.A.transpose(1, 2)) @ self.B.transpose(1, 2)) * self._scale + + +class BatchedTTTLoRA(nn.Module): + def __init__( + self, bsz, model, rank, + q_lora=True, k_lora=True, v_lora=True, mlp_lora=True, o_lora=True, + ): + super().__init__() + self.bsz = bsz + dim = model.qo_bank.shape[-1] + vocab = model.tok_emb.num_embeddings + if getattr(model, "looping_active", False): + num_slots = len(model.encoder_indices) + len(model.decoder_indices) + else: + num_slots = len(model.blocks) + kv_dim = model.blocks[0].attn.num_kv_heads * ( + dim // model.blocks[0].attn.num_heads + ) + embed_dim = model.tok_emb.embedding_dim + self.lm_head_lora = BatchedLinearLoRA(bsz, embed_dim, vocab, rank) + self.q_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_slots)] + ) + if q_lora + else None + ) + self.v_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, kv_dim, rank) for _ in range(num_slots)] + ) + if v_lora + else None + ) + self.k_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, kv_dim, rank) for _ in range(num_slots)] + ) + if k_lora + else None + ) + self.mlp_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_slots)] + ) + if mlp_lora + else None + ) + self.o_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_slots)] + ) + if o_lora + else None + ) + + def reset(self): + with torch.no_grad(): + self.lm_head_lora.reset() + for loras in [self.q_loras, self.v_loras, self.k_loras, + self.mlp_loras, self.o_loras]: + if loras is not None: + for lora in loras: + lora.reset() + + +# Polar Express per-iteration minimax Newton-Schulz coefficients (PR #1344). +# Replaces the fixed (3.4445, -4.775, 2.0315) coefficients of stock Muon. +# Applied at backend_steps=5 — taking more than 5 iterations from this list +# falls back to the final (converged) tuple via the slice guard below. +_PE_COEFFS = ( + (8.156554524902461, -22.48329292557795, 15.878769915207462), + (4.042929935166739, -2.808917465908714, 0.5000178451051316), + (3.8916678022926607, -2.772484153217685, 0.5060648178503393), + (3.285753657755655, -2.3681294933425376, 0.46449024233003106), + (2.3465413258596377, -1.7097828382687081, 0.42323551169305323), +) + + +@torch.compile +def zeropower_via_newtonschulz5(G, steps=10, eps=1e-07): + was_2d = G.ndim == 2 + if was_2d: + G = G.unsqueeze(0) + X = G.bfloat16() + transposed = X.size(-2) > X.size(-1) + if transposed: + X = X.mT + X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps) + coeffs = _PE_COEFFS[:steps] if steps <= len(_PE_COEFFS) else _PE_COEFFS + for a, b, c in coeffs: + A = X @ X.mT + B = b * A + c * (A @ A) + X = a * X + B @ X + if transposed: + X = X.mT + if was_2d: + X = X.squeeze(0) + return X + + +class Muon(torch.optim.Optimizer): + def __init__( + self, + params, + lr, + momentum, + backend_steps, + nesterov=True, + weight_decay=0.0, + row_normalize=False, + ): + super().__init__( + params, + dict( + lr=lr, + momentum=momentum, + backend_steps=backend_steps, + nesterov=nesterov, + weight_decay=weight_decay, + row_normalize=row_normalize, + ), + ) + self._built = False + + def _build(self): + self._distributed = dist.is_available() and dist.is_initialized() + self._world_size = dist.get_world_size() if self._distributed else 1 + self._rank = dist.get_rank() if self._distributed else 0 + ws = self._world_size + self._bank_meta = [] + for group in self.param_groups: + for p in group["params"]: + B = p.shape[0] + padded_B = ((B + ws - 1) // ws) * ws + shard_B = padded_B // ws + tail = p.shape[1:] + dev = p.device + self._bank_meta.append({ + "p": p, + "B": B, + "padded_grad": torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + "shard": torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + "shard_mom": torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + "full_update": torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + "scale": max(1, p.shape[-2] / p.shape[-1]) ** 0.5, + }) + self._bank_meta.sort(key=lambda m: -m["p"].numel()) + self._built = True + + def launch_reduce_scatters(self): + if not self._built: + self._build() + if not self._distributed: + return + self._rs_futures = [] + for m in self._bank_meta: + p = m["p"] + if p.grad is None: + self._rs_futures.append(None) + continue + pg = m["padded_grad"] + pg[: m["B"]].copy_(p.grad) + fut = dist.reduce_scatter_tensor( + m["shard"], pg, op=dist.ReduceOp.AVG, async_op=True + ) + self._rs_futures.append(fut) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + if not self._built: + self._build() + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + row_normalize = group.get("row_normalize", False) + prev_ag_handle = None + prev_m = None + sharded = self._distributed and hasattr(self, "_rs_futures") + for idx, m in enumerate(self._bank_meta): + p = m["p"] + if p.grad is None: + continue + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m["p"] + upd = prev_m["full_update"][: prev_m["B"]] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd, alpha=-lr * prev_m["scale"]) + if sharded and self._rs_futures[idx] is not None: + self._rs_futures[idx].wait() + g = m["shard"] + buf = m["shard_mom"] + else: + g = p.grad.bfloat16() + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + update = g.add(buf, alpha=momentum) + else: + update = buf + if row_normalize: + rn = update.float().norm(dim=-1, keepdim=True).clamp_min(1e-07) + update = update / rn.to(update.dtype) + update = zeropower_via_newtonschulz5(update, steps=backend_steps) + if sharded: + prev_ag_handle = dist.all_gather_into_tensor( + m["full_update"], update, async_op=True + ) + prev_m = m + else: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + p.add_(update, alpha=-lr * m["scale"]) + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m["p"] + upd = prev_m["full_update"][: prev_m["B"]] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd, alpha=-lr * prev_m["scale"]) + if hasattr(self, "_rs_futures"): + del self._rs_futures + return loss + + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,skip_gates,parallel_post_lambdas,parallel_resid_lambdas,attn_gate_proj,attn_gate_w,smear_gate,smear_lambda", + ).split(",") + if pattern +) + + +PACKED_REPLICATED_GRAD_MAX_NUMEL = 1 << 15 + + +class Optimizers: + def __init__(self, h, base_model): + matrix_params = [ + base_model.qo_bank, + base_model.kv_bank, + base_model.mlp_up_bank, + base_model.mlp_down_bank, + ] + block_named_params = list(base_model.blocks.named_parameters()) + scalar_params = [ + p + for (name, p) in block_named_params + if p.ndim < 2 + or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if base_model.skip_gates is not None and base_model.skip_gates.numel() > 0: + scalar_params.append(base_model.skip_gates) + if base_model.parallel_post_lambdas is not None: + scalar_params.append(base_model.parallel_post_lambdas) + if base_model.parallel_resid_lambdas is not None: + scalar_params.append(base_model.parallel_resid_lambdas) + # SmearGate params live on GPT root (not in .blocks), so add them by hand. + # Both are tiny (gate_window scalars + 1 lambda). Optimized via scalar Adam. + if getattr(base_model, "smear_gate_enabled", False): + scalar_params.append(base_model.smear_gate.weight) + scalar_params.append(base_model.smear_lambda) + token_lr = h.tied_embed_lr if h.tie_embeddings else h.embed_lr + tok_params = [ + {"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr} + ] + self.optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=h.embed_wd, + fused=True, + ) + self.optimizer_muon = Muon( + matrix_params, + lr=h.matrix_lr, + momentum=h.muon_momentum, + backend_steps=h.muon_backend_steps, + weight_decay=h.muon_wd, + row_normalize=h.muon_row_normalize, + ) + for group in self.optimizer_muon.param_groups: + group["base_lr"] = h.matrix_lr + self.optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": h.scalar_lr, "base_lr": h.scalar_lr}], + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=h.adam_wd, + fused=True, + ) + jepa_params = ( + list(base_model.jepa_predictor.parameters()) + if getattr(base_model, "jepa_predictor", None) is not None + else [] + ) + if getattr(base_model, "jepa_bwd_predictor", None) is not None: + jepa_params.extend(base_model.jepa_bwd_predictor.parameters()) + if jepa_params: + self.optimizer_jepa = torch.optim.AdamW( + [ + { + "params": jepa_params, + "lr": h.jepa_lite_lr, + "base_lr": h.jepa_lite_lr, + } + ], + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=0.0, + fused=True, + ) + else: + self.optimizer_jepa = None + self.optimizers = [ + self.optimizer_tok, + self.optimizer_muon, + self.optimizer_scalar, + ] + if self.optimizer_jepa is not None: + self.optimizers.append(self.optimizer_jepa) + self.replicated_params = list(tok_params[0]["params"]) + self.replicated_params.extend(scalar_params) + self.replicated_params.extend(jepa_params) + self.replicated_large_params = [] + self.replicated_packed_params = [] + for p in self.replicated_params: + if p.numel() <= PACKED_REPLICATED_GRAD_MAX_NUMEL: + self.replicated_packed_params.append(p) + else: + self.replicated_large_params.append(p) + self._aux_stream = torch.cuda.Stream() + + def __iter__(self): + return iter(self.optimizers) + + def zero_grad_all(self): + for opt in self.optimizers: + opt.zero_grad(set_to_none=True) + + def _all_reduce_packed_grads(self): + grads_by_key = collections.defaultdict(list) + for p in self.replicated_packed_params: + if p.grad is not None: + grads_by_key[(p.grad.device, p.grad.dtype)].append(p.grad) + for grads in grads_by_key.values(): + flat = torch.empty( + sum(g.numel() for g in grads), + device=grads[0].device, + dtype=grads[0].dtype, + ) + offset = 0 + for g in grads: + n = g.numel() + flat[offset : offset + n].copy_(g.contiguous().view(-1)) + offset += n + dist.all_reduce(flat, op=dist.ReduceOp.AVG) + offset = 0 + for g in grads: + n = g.numel() + g.copy_(flat[offset : offset + n].view_as(g)) + offset += n + + def step(self, distributed=False): + self.optimizer_muon.launch_reduce_scatters() + if distributed: + reduce_handles = [ + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG, async_op=True) + for p in self.replicated_large_params + if p.grad is not None + ] + self._all_reduce_packed_grads() + for handle in reduce_handles: + handle.wait() + self._aux_stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self._aux_stream): + self.optimizer_tok.step() + self.optimizer_scalar.step() + if self.optimizer_jepa is not None: + self.optimizer_jepa.step() + self.optimizer_muon.step() + torch.cuda.current_stream().wait_stream(self._aux_stream) + self.zero_grad_all() + + +def restore_fp32_params(model): + for module in model.modules(): + if isinstance(module, CastedLinear): + module.float() + for name, param in model.named_parameters(): + if ( + param.ndim < 2 + or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ) and param.dtype != torch.float32: + param.data = param.data.float() + if hasattr(model, "qo_bank") and model.qo_bank is not None: + model.qo_bank.data = model.qo_bank.data.float() + model.kv_bank.data = model.kv_bank.data.float() + model.mlp_up_bank.data = model.mlp_up_bank.data.float() + model.mlp_down_bank.data = model.mlp_down_bank.data.float() + + +def collect_hessians(model, train_loader, h, device, n_calibration_batches=64): + hessians = {} + act_sumsq = {} + act_counts = {} + hooks = [] + for i, block in enumerate(model.blocks): + block.attn._calib = True + block.mlp._calib = True + block.mlp.use_fused = False + + def make_attn_hook(layer_idx): + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + x_sq = x.square().sum(dim=0) + x_count = x.shape[0] + for suffix in ["c_q", "c_k", "c_v"]: + name = f"blocks.{layer_idx}.attn.{suffix}.weight" + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + if name not in act_sumsq: + act_sumsq[name] = torch.zeros( + x.shape[1], dtype=torch.float32, device=device + ) + act_counts[name] = 0 + act_sumsq[name] += x_sq + act_counts[name] += x_count + y = module._last_proj_input + if y is not None: + y = y.float() + if y.ndim == 3: + y = y.reshape(-1, y.shape[-1]) + name = f"blocks.{layer_idx}.attn.proj.weight" + if name not in hessians: + hessians[name] = torch.zeros( + y.shape[1], y.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(y.T, y) + if name not in act_sumsq: + act_sumsq[name] = torch.zeros( + y.shape[1], dtype=torch.float32, device=device + ) + act_counts[name] = 0 + act_sumsq[name] += y.square().sum(dim=0) + act_counts[name] += y.shape[0] + return hook_fn + + def make_mlp_hook(layer_idx): + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + name = f"blocks.{layer_idx}.mlp.fc.weight" + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + if name not in act_sumsq: + act_sumsq[name] = torch.zeros( + x.shape[1], dtype=torch.float32, device=device + ) + act_counts[name] = 0 + act_sumsq[name] += x.square().sum(dim=0) + act_counts[name] += x.shape[0] + h_act = module._last_down_input + if h_act is not None: + h_act = h_act.float() + if h_act.ndim == 3: + h_act = h_act.reshape(-1, h_act.shape[-1]) + name = f"blocks.{layer_idx}.mlp.proj.weight" + if name not in hessians: + hessians[name] = torch.zeros( + h_act.shape[1], h_act.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(h_act.T, h_act) + if name not in act_sumsq: + act_sumsq[name] = torch.zeros( + h_act.shape[1], dtype=torch.float32, device=device + ) + act_counts[name] = 0 + act_sumsq[name] += h_act.square().sum(dim=0) + act_counts[name] += h_act.shape[0] + return hook_fn + + for i, block in enumerate(model.blocks): + hooks.append(block.attn.register_forward_hook(make_attn_hook(i))) + hooks.append(block.mlp.register_forward_hook(make_mlp_hook(i))) + + # Hessian hooks for embedding factorization projection layers + def make_linear_input_hook(weight_name): + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + if weight_name not in hessians: + hessians[weight_name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[weight_name].addmm_(x.T, x) + return hook_fn + + if model.tie_embeddings: + hook_module = model.final_norm + + def make_output_hook(name): + def hook_fn(module, inp, out): + x = out.detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + if name not in act_sumsq: + act_sumsq[name] = torch.zeros( + x.shape[1], dtype=torch.float32, device=device + ) + act_counts[name] = 0 + act_sumsq[name] += x.square().sum(dim=0) + act_counts[name] += x.shape[0] + return hook_fn + + hooks.append( + hook_module.register_forward_hook(make_output_hook("tok_emb.weight")) + ) + model.eval() + with torch.no_grad(): + for _ in range(n_calibration_batches): + x, _ = train_loader.next_batch(h.train_batch_tokens, h.grad_accum_steps) + model.forward_logits(x) + for hook in hooks: + hook.remove() + for i, block in enumerate(model.blocks): + block.attn._calib = False + block.mlp._calib = False + block.mlp.use_fused = True + for name in hessians: + hessians[name] = hessians[name].cpu() / n_calibration_batches + act_stats = {} + for name, sumsq in act_sumsq.items(): + count = max(act_counts.get(name, 0), 1) + act_stats[name] = (sumsq / count).sqrt().cpu() + return hessians, act_stats + + +def gptq_quantize_weight( + w, + H, + clip_sigmas=3.0, + clip_range=63, + block_size=128, + protect_groups=None, + group_size=None, + protect_clip_range=None, +): + W_orig = w.float().clone() + rows, cols = W_orig.shape + H = H.float().clone() + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + damp = 0.01 * H.diag().mean() + H.diagonal().add_(damp) + perm = torch.argsort(H.diag(), descending=True) + invperm = torch.argsort(perm) + W_perm = W_orig[:, perm].clone() + W_perm[:, dead[perm]] = 0 + H = H[perm][:, perm] + H_flip = torch.flip(H, dims=(0, 1)) + L_flip = torch.linalg.cholesky(H_flip) + U = torch.flip(L_flip, dims=(0, 1)) + eye = torch.eye(H.shape[0], device=H.device, dtype=H.dtype) + Hinv = torch.linalg.solve_triangular(U, eye, upper=True) + row_std = W_orig.std(dim=1) + s = (clip_sigmas * row_std / clip_range).clamp_min(1e-10).to(torch.float16) + sf = s.float() + protect_meta = None + protect_mask_perm = None + s_hi = None + sf_hi = None + if ( + protect_groups + and group_size is not None + and protect_clip_range is not None + and protect_clip_range > clip_range + ): + protect_mask = torch.zeros(cols, dtype=torch.bool) + starts = [] + for (start, end) in protect_groups: + if start < 0 or end > cols or end <= start: + continue + protect_mask[start:end] = True + starts.append(start) + if starts: + protect_mask_perm = protect_mask[perm] + s_hi = (clip_sigmas * row_std / protect_clip_range).clamp_min(1e-10).to( + torch.float16 + ) + sf_hi = s_hi.float() + protect_meta = { + "starts": torch.tensor(starts, dtype=torch.int16), + "size": int(group_size), + "s_hi": s_hi, + } + Q = torch.zeros(rows, cols, dtype=torch.int8) + W_work = W_perm.clone() + for i1 in range(0, cols, block_size): + i2 = min(i1 + block_size, cols) + W_block = W_work[:, i1:i2].clone() + Hinv_block = Hinv[i1:i2, i1:i2] + Err = torch.zeros(rows, i2 - i1) + for j in range(i2 - i1): + w_col = W_block[:, j] + d = Hinv_block[j, j] + if protect_mask_perm is not None and bool(protect_mask_perm[i1 + j]): + q_col = torch.clamp( + torch.round(w_col / sf_hi), + -protect_clip_range, + protect_clip_range, + ) + w_recon = q_col.float() * sf_hi + else: + q_col = torch.clamp(torch.round(w_col / sf), -clip_range, clip_range) + w_recon = q_col.float() * sf + Q[:, i1 + j] = q_col.to(torch.int8) + err = (w_col - w_recon) / d + Err[:, j] = err + W_block[:, j:] -= err.unsqueeze(1) * Hinv_block[j, j:].unsqueeze(0) + if i2 < cols: + W_work[:, i2:] -= Err @ Hinv[i1:i2, i2:] + return Q[:, invperm], s, protect_meta + + +def _quantize_gate_int8_row(w): + # Symmetric int8-per-row quantization for small gate tensors. w shape + # (R, C) -> (R,) scales in fp16, int8 values in [-127, 127]. Single scale + # per row keeps accuracy high while halving storage vs fp16. + W = w.float().contiguous() + row_max = W.abs().amax(dim=1).clamp_min(1e-10) + s = (row_max / 127.0).to(torch.float16) + sf = s.float().view(-1, 1) + q = torch.clamp(torch.round(W / sf), -127, 127).to(torch.int8) + return q, s + + +def _lqer_pack(A, B, bits): + rng = 2 ** (bits - 1) - 1 + sA = (A.abs().amax(dim=1).clamp_min(1e-10) / rng).to(torch.float16) + sB = (B.abs().amax(dim=1).clamp_min(1e-10) / rng).to(torch.float16) + qA = torch.clamp(torch.round(A / sA.float().view(-1, 1)), -rng, rng).to(torch.int8) + qB = torch.clamp(torch.round(B / sB.float().view(-1, 1)), -rng, rng).to(torch.int8) + return qA, sA, qB, sB + + +def _lqer_pack_asym(A, B, g=64): + # A: INT2 per-matrix scalar (signed [-2,1], scale = |A|max/1.5). + sA = (A.abs().amax().clamp_min(1e-10) / 1.5).to(torch.float16) + qA = torch.clamp(torch.round(A / sA.float()), -2, 1).to(torch.int8) + # B: INT4 groupwise g over flattened B (signed [-8,7], per-group scale). + Bf = B.reshape(-1, g) + Bmax = Bf.abs().amax(dim=-1, keepdim=True).clamp_min(1e-10) + sB = (Bmax / 7.5).to(torch.float16).reshape(-1) + qB = torch.clamp(torch.round(Bf / sB.float().reshape(-1, 1)), -8, 7).to( + torch.int8 + ).reshape(B.shape) + return qA, sA, qB, sB + + +def _lqer_fit_quantized(E, h): + U, S, Vh = torch.linalg.svd(E, full_matrices=False) + r = min(h.lqer_rank, S.numel()) + if r <= 0: + return None + A = (U[:, :r] * S[:r]).contiguous() + B = Vh[:r, :].contiguous() + asym_on = bool(getattr(h, "lqer_asym_enabled", False)) + asym_g = int(getattr(h, "lqer_asym_group", 64)) + if asym_on and B.numel() % asym_g == 0: + qA, sA, qB, sB = _lqer_pack_asym(A, B, asym_g) + A_hat = qA.float() * float(sA) + g_sz = qB.numel() // sB.numel() + B_hat = (qB.reshape(-1, g_sz).float() * sB.float().view(-1, 1)).reshape( + qB.shape + ) + return { + "kind": "asym", + "qA": qA, + "sA": sA, + "qB": qB, + "sB": sB, + "delta": A_hat @ B_hat, + } + qA, sA, qB, sB = _lqer_pack(A, B, h.lqer_factor_bits) + A_hat = qA.float() * sA.float().view(-1, 1) + B_hat = qB.float() * sB.float().view(-1, 1) + return { + "kind": "sym", + "qA": qA, + "sA": sA, + "qB": qB, + "sB": sB, + "delta": A_hat @ B_hat, + } + + +def _awq_lite_group_candidates(w, act_rms, group_size): + cols = w.shape[1] + n_groups = cols // group_size + if n_groups <= 0: + return [] + weight_score = w.float().abs().mean(dim=0) + saliency = act_rms.float() * weight_score + cands = [] + for gi in range(n_groups): + start = gi * group_size + end = start + group_size + score = float(saliency[start:end].sum()) + cands.append((score, start, end)) + return cands + + +def gptq_mixed_quantize(state_dict, hessians, act_stats, h): + result = {} + meta = {} + quant_gate = bool(getattr(h, "gated_attn_quant_gate", False)) + lqer_on = bool(getattr(h, "lqer_enabled", False)) + awq_on = bool(getattr(h, "awq_lite_enabled", False)) + lqer_cands = {} + awq_selected = collections.defaultdict(list) + if awq_on: + awq_cands = [] + for (name, tensor) in state_dict.items(): + t = tensor.detach().cpu().contiguous() + if t.is_floating_point() and t.numel() > 65536 and name in act_stats: + bits = h.embed_bits if "tok_emb" in name else h.matrix_bits + if bits < h.awq_lite_bits: + for score, start, end in _awq_lite_group_candidates( + t, act_stats[name], h.awq_lite_group_size + ): + awq_cands.append((score, name, start, end)) + awq_cands.sort(key=lambda x: -x[0]) + for (_score, name, start, end) in awq_cands[: h.awq_lite_group_top_k]: + awq_selected[name].append((start, end)) + for (name, tensor) in state_dict.items(): + t = tensor.detach().cpu().contiguous() + # Dedicated int8-per-row path for attn_gate_w (bypasses both GPTQ and + # fp16 passthrough). Applied BEFORE the numel<=65536 passthrough check + # so the gate tensor is routed here instead of to fp16. + if ( + quant_gate + and t.is_floating_point() + and t.ndim == 2 + and name.endswith(".attn_gate_w") + # Dense GatedAttn: (num_heads, dim) = (8, 512) = 4096. + # Sparse gate: (num_heads, gate_window) = (8, 12) = 96. + # Both need int8-per-row routing; the 1024 lower bound in stock + # PR-1736 presumed dense-only. Widen to catch both. + and 32 <= t.numel() <= 8192 + ): + gq, gs = _quantize_gate_int8_row(t) + result[name + ".gq"] = gq + result[name + ".gs"] = gs + meta[name] = "gate_int8_row" + continue + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough (float16)" + continue + if "tok_emb" in name: + cs = h.embed_clip_sigmas + elif ".mlp." in name: + cs = h.mlp_clip_sigmas + elif ".attn." in name: + cs = h.attn_clip_sigmas + else: + cs = h.matrix_clip_sigmas + bits = h.embed_bits if "tok_emb" in name else h.matrix_bits + clip_range = 2 ** (bits - 1) - 1 + q, s, protect_meta = gptq_quantize_weight( + t, + hessians[name], + clip_sigmas=cs, + clip_range=clip_range, + protect_groups=awq_selected.get(name), + group_size=h.awq_lite_group_size if name in awq_selected else None, + protect_clip_range=(2 ** (h.awq_lite_bits - 1) - 1) + if name in awq_selected + else None, + ) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = f"gptq (int{bits})" + W_q = q.float() * s.float().view(-1, 1) + if protect_meta is not None: + result[name + ".awqg_start"] = protect_meta["starts"] + result[name + ".awqg_s_hi"] = protect_meta["s_hi"] + result[name + ".awqg_size"] = torch.tensor( + protect_meta["size"], dtype=torch.int16 + ) + meta[name] = meta[name] + f"+awqgrpint{h.awq_lite_bits}" + gsz = protect_meta["size"] + for start in protect_meta["starts"].tolist(): + W_q[:, start : start + gsz] = ( + q[:, start : start + gsz].float() + * protect_meta["s_hi"].float().view(-1, 1) + ) + if lqer_on: + # LQER is fit on top of the fully realized GPTQ base, which already + # includes any higher-precision AWQ-protected groups. + scope = str(getattr(h, "lqer_scope", "all")).lower() + scope_ok = ( + scope == "all" + or (scope == "mlp" and ".mlp." in name) + or (scope == "attn" and ".attn." in name) + or (scope == "embed" and "tok_emb" in name) + ) + if scope_ok: + E = t.float() - W_q + err_norm = float(E.norm()) + if err_norm > 0: + lqer_cands[name] = (E, err_norm) + if lqer_on and lqer_cands: + if bool(getattr(h, "lqer_gain_select", False)): + scored = [] + for (name, (E, base_err)) in lqer_cands.items(): + fit = _lqer_fit_quantized(E, h) + if fit is None: + continue + new_err = float((E - fit["delta"]).norm()) + gain = base_err - new_err + if gain > 0: + scored.append((gain, name, fit)) + scored.sort(key=lambda x: -x[0]) + for (_gain, name, fit) in scored[: h.lqer_top_k]: + if fit["kind"] == "asym": + result[name + ".lqA_a"] = fit["qA"] + result[name + ".lqAs_a"] = fit["sA"] + result[name + ".lqB_a"] = fit["qB"] + result[name + ".lqBs_a"] = fit["sB"] + meta[name] = meta[name] + "+lqer_asym" + else: + result[name + ".lqA"] = fit["qA"] + result[name + ".lqAs"] = fit["sA"] + result[name + ".lqB"] = fit["qB"] + result[name + ".lqBs"] = fit["sB"] + meta[name] = meta[name] + "+lqer" + else: + top = sorted(lqer_cands.items(), key=lambda kv: -kv[1][1])[: h.lqer_top_k] + asym_on = bool(getattr(h, "lqer_asym_enabled", False)) + asym_g = int(getattr(h, "lqer_asym_group", 64)) + for (name, (E, _)) in top: + U, S, Vh = torch.linalg.svd(E, full_matrices=False) + r = min(h.lqer_rank, S.numel()) + A = (U[:, :r] * S[:r]).contiguous() + B = Vh[:r, :].contiguous() + if asym_on and B.numel() % asym_g == 0: + qA, sA, qB, sB = _lqer_pack_asym(A, B, asym_g) + result[name + ".lqA_a"] = qA + result[name + ".lqAs_a"] = sA + result[name + ".lqB_a"] = qB + result[name + ".lqBs_a"] = sB + meta[name] = meta[name] + "+lqer_asym" + else: + qA, sA, qB, sB = _lqer_pack(A, B, h.lqer_factor_bits) + result[name + ".lqA"] = qA + result[name + ".lqAs"] = sA + result[name + ".lqB"] = qB + result[name + ".lqBs"] = sB + meta[name] = meta[name] + "+lqer" + categories = collections.defaultdict(set) + for (name, cat) in meta.items(): + short = re.sub("\\.\\d+$", "", re.sub("blocks\\.\\d+", "blocks", name)) + categories[cat].add(short) + log("Quantized weights:") + for cat in sorted(categories): + log(f" {cat}: {', '.join(sorted(categories[cat]))}") + return result, meta + +def dequantize_mixed(result, meta, template_sd): + out = {} + for (name, orig) in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if "passthrough" in info: + t = result[name] + if t.dtype == torch.float16 and orig_dtype in ( + torch.float32, + torch.bfloat16, + ): + t = t.to(orig_dtype) + out[name] = t + continue + if info == "gate_int8_row": + gq = result[name + ".gq"] + gs = result[name + ".gs"] + out[name] = (gq.float() * gs.float().view(-1, 1)).to(orig_dtype) + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + W = q.float() * s.float().view(q.shape[0], *[1] * (q.ndim - 1)) + else: + W = q.float() * float(s.item()) + if "awqgrpint" in info: + starts = result[name + ".awqg_start"].tolist() + s_hi = result[name + ".awqg_s_hi"].float() + gsz = int(result[name + ".awqg_size"].item()) + for start in starts: + W[:, start : start + gsz] = ( + q[:, start : start + gsz].float() * s_hi.view(-1, 1) + ) + if "lqer_asym" in info: + qA_t = result[name + ".lqA_a"] + sA_t = result[name + ".lqAs_a"] + qB_t = result[name + ".lqB_a"] + sB_t = result[name + ".lqBs_a"] + qA = qA_t.float() * float(sA_t) + g_sz = qB_t.numel() // sB_t.numel() + qB = (qB_t.reshape(-1, g_sz).float() * sB_t.float().view(-1, 1)).reshape( + qB_t.shape + ) + W = W + qA @ qB + elif "lqer" in info: + qA = result[name + ".lqA"].float() * result[name + ".lqAs"].float().view(-1, 1) + qB = result[name + ".lqB"].float() * result[name + ".lqBs"].float().view(-1, 1) + W = W + qA @ qB + out[name] = W.to(orig_dtype) + return out + + +_BSHF_MAGIC = b"BSHF" + + +# ── Per-group lrzip compression (ported from PR#1586 via PR#1667/1729) ──────── + +_GROUP_ORDER = [ + "_tok_emb.weight.q", + "attn.c_k.weight.q", "attn.c_q.weight.q", + "attn.c_v.weight.q", "attn.proj.weight.q", + "mlp.fc.weight.q", "mlp.proj.weight.q", +] +_SIMSORT_KEYS = {"_tok_emb.weight.q", "attn.c_q.weight.q", "mlp.fc.weight.q"} +_PACK_MAGIC = b"PGRP" + + +def _similarity_sort_l1(matrix): + import numpy as _np + n = matrix.shape[0] + used = _np.zeros(n, dtype=bool) + order = [0] + used[0] = True + cur = matrix[0].astype(_np.float32) + for _ in range(n - 1): + dists = _np.sum(_np.abs(matrix[~used].astype(_np.float32) - cur), axis=1) + unused = _np.where(~used)[0] + best = unused[_np.argmin(dists)] + order.append(best) + used[best] = True + cur = matrix[best].astype(_np.float32) + return _np.array(order, dtype=_np.uint16) + + +def _lrzip_compress(data, tmpdir, label): + inp = os.path.join(tmpdir, f"{label}.bin") + out = f"{inp}.lrz" + with open(inp, "wb") as f: + f.write(data) + subprocess.run(["lrzip", "-z", "-L", "9", "-o", out, inp], capture_output=True, check=True) + with open(out, "rb") as f: + result = f.read() + os.remove(inp); os.remove(out) + return result + + +def _lrzip_decompress(data, tmpdir, label): + inp = os.path.join(tmpdir, f"{label}.lrz") + out = os.path.join(tmpdir, f"{label}.bin") + with open(inp, "wb") as f: + f.write(data) + subprocess.run(["lrzip", "-d", "-f", "-o", out, inp], capture_output=True, check=True) + with open(out, "rb") as f: + result = f.read() + os.remove(inp); os.remove(out) + return result + + +def _pack_streams(streams): + import struct + n = len(streams) + hdr = _PACK_MAGIC + struct.pack("", p) + if m: + return bytes([int(m.group(1), 16)]) + return (" " + p[1:]).encode() if p.startswith("▁") else p.encode() + + +def _ppm_mixture_bpb(tgt_np, lp_np, sp, O=4, H=0.9, L_=0.05, T=0.9, token_byte_lens_np=None): + V = sp.vocab_size() + piece_bytes = [None] * V + piece_lens = np.zeros(V, dtype=np.int32) + for i in range(V): + b = _ppm_piece_bytes(sp, i) + piece_bytes[i] = b + piece_lens[i] = len(b) + if token_byte_lens_np is None: + per_tok_len = piece_lens[tgt_np] + bs = b''.join(piece_bytes[int(t)] for t in tgt_np) + kept_lp = lp_np + else: + chunks = [] + kept_lp_parts = [] + lens_parts = [] + for t, lp, side_len in zip(tgt_np, lp_np, token_byte_lens_np): + side_len = int(side_len) + if side_len <= 0: + continue + b = piece_bytes[int(t)] + if not b: + continue + if len(b) > side_len: + b = b[:side_len] + elif len(b) < side_len: + b = b + b[-1:] * (side_len - len(b)) + chunks.append(b) + kept_lp_parts.append(float(lp)) + lens_parts.append(side_len) + if not chunks: + return float("inf") + bs = b''.join(chunks) + per_tok_len = np.asarray(lens_parts, dtype=np.int32) + kept_lp = np.asarray(kept_lp_parts, dtype=np.float64) + N = len(bs) + rep_lp = np.repeat(kept_lp.astype(np.float64), per_tok_len) + rep_len = np.repeat(per_tok_len.astype(np.float64), per_tok_len) + nlp = np.where(rep_len > 0, rep_lp / rep_len, 0.0) + tabs = [dict() for _ in range(O + 1)] + plp = np.empty(N, dtype=np.float64) + cf = np.empty(N, dtype=np.float64) + LN256 = math.log(1 / 256) + log_ = math.log + h_ctx = b'' + for i in range(N): + x = bs[i] + if i == 0: + plp[i] = LN256 + cf[i] = 1 / 256 + else: + esc = 1.0 + pf = 0.0 + cf_mx = 0 + cf_tot = 256 + cf_seen = False + lim = O if i > O else i + for o in range(lim, -1, -1): + k = h_ctx[-o:] if o else b'' + e = tabs[o].get(k) + if e is None: + continue + if not cf_seen: + cf_mx = e[1] + cf_tot = e[0] + cf_seen = True + tot = e[0] + d = e[2] + c = d.get(x, 0) + if c > 0: + pf = esc * (2 * c - 1) / (2 * tot) + break + esc *= len(d) / (2 * tot) + else: + pf = esc / 256 + if pf < 1e-20: + pf = 1e-20 + plp[i] = log_(pf) + cf[i] = (cf_mx / cf_tot) if cf_seen else 1 / 256 + for o in range(O + 1): + k = h_ctx[-o:] if o else b'' + e = tabs[o].get(k) + if e is None: + tabs[o][k] = [1, 1, {x: 1}] + else: + e[0] += 1 + d = e[2] + cnt = d.get(x, 0) + 1 + d[x] = cnt + if cnt > e[1]: + e[1] = cnt + h_ctx = (h_ctx + bytes([x]))[-O:] + nn_prob = np.exp(nlp) + ppm_prob = np.exp(plp) + + def _mix_bpb(Hv, Lv, Tv): + lam_v = np.where(cf > Tv, Lv, Hv) + pm_v = lam_v * nn_prob + (1 - lam_v) * ppm_prob + return float(-np.log2(np.maximum(pm_v, 1e-300)).sum() / N) + + default_bpb = _mix_bpb(H, L_, T) + if os.environ.get("PPM_SWEEP_GRID", "0") == "1": + hs = [float(x) for x in os.environ.get("PPM_SWEEP_HS", str(H)).split(",") if x.strip()] + ls = [float(x) for x in os.environ.get("PPM_SWEEP_LS", str(L_)).split(",") if x.strip()] + ts = [float(x) for x in os.environ.get("PPM_SWEEP_TS", str(T)).split(",") if x.strip()] + combo_count = len(hs) * len(ls) * len(ts) + max_combos = int(os.environ.get("PPM_SWEEP_MAX_COMBOS", "256")) + if combo_count > max_combos and os.environ.get("PPM_SWEEP_ALLOW_SLOW", "0") != "1": + log( + f"ppm_sweep skipped: combos={combo_count} max={max_combos}; " + "dump inputs and replay offline, or set PPM_SWEEP_ALLOW_SLOW=1" + ) + return default_bpb + best = (default_bpb, H, L_, T) + for Hv in hs: + for Lv in ls: + for Tv in ts: + bpb = _mix_bpb(Hv, Lv, Tv) + if bpb < best[0]: + best = (bpb, Hv, Lv, Tv) + log( + f"ppm_sweep best_bpb:{best[0]:.8f} H={best[1]} L={best[2]} T={best[3]} " + f"default_bpb:{default_bpb:.8f}" + ) + if os.environ.get("PPM_SWEEP_APPLY", "0") == "1": + return best[0] + return default_bpb + + +def eval_val_ppm_sliding(h, device, val_data, model, batch_seqs=32): + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + model.eval() + seq_len = h.eval_seq_len + stride = h.eval_stride + context_size = seq_len - stride + total_tokens = val_data.val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) if ws + context_size < total_tokens] + total_windows = len(window_starts) + my_s = total_windows * h.rank // h.world_size + my_e = total_windows * (h.rank + 1) // h.world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + tga_local = [] + lpa_local = [] + bla_local = [] + fwd_fn = model.module.forward_logits if hasattr(model, 'module') else model.forward_logits + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens = [] + for i, ws in enumerate(batch_ws): + we = min(ws + seq_len, total_tokens) + wlen = we - ws + wlens.append(wlen) + chunk = val_data.val_tokens[ws:we + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type='cuda', dtype=torch.bfloat16): + logits = fwd_fn(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction='none', + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else context_size + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + if val_data.val_bytes is not None: + tb = val_data.val_bytes[ws + s + 1: ws + wlen + 1].to(device=device, dtype=torch.float64) + else: + prev = x_batch[i, s:wlen] + tb = val_data.base_bytes_lut[tgt].to(torch.float64) + tb += (val_data.has_leading_space_lut[tgt] & ~val_data.is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + tga_local.append(tgt.cpu().to(torch.int64)) + lpa_local.append((-scored_nll).cpu().to(torch.float64)) + bla_local.append(tb.cpu().to(torch.int32)) + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss, val_bpb = _loss_bpb(loss_sum, token_count, byte_count) + if h.ppm_mixer_enabled: + tga_local_cat = torch.cat(tga_local) if tga_local else torch.zeros(0, dtype=torch.int64) + lpa_local_cat = torch.cat(lpa_local) if lpa_local else torch.zeros(0, dtype=torch.float64) + bla_local_cat = torch.cat(bla_local) if bla_local else torch.zeros(0, dtype=torch.int32) + if dist.is_available() and dist.is_initialized(): + local_size = torch.tensor([tga_local_cat.numel()], dtype=torch.int64, device=device) + sizes = [torch.zeros(1, dtype=torch.int64, device=device) for _ in range(h.world_size)] + dist.all_gather(sizes, local_size) + sizes_list = [int(s.item()) for s in sizes] + max_size = max(sizes_list) if sizes_list else 0 + tga_pad = torch.zeros(max_size, dtype=torch.int64, device=device) + lpa_pad = torch.zeros(max_size, dtype=torch.float64, device=device) + bla_pad = torch.zeros(max_size, dtype=torch.int32, device=device) + tga_pad[:tga_local_cat.numel()] = tga_local_cat.to(device) + lpa_pad[:lpa_local_cat.numel()] = lpa_local_cat.to(device) + bla_pad[:bla_local_cat.numel()] = bla_local_cat.to(device) + if h.rank == 0: + gather_t = [torch.zeros(max_size, dtype=torch.int64, device=device) for _ in range(h.world_size)] + gather_l = [torch.zeros(max_size, dtype=torch.float64, device=device) for _ in range(h.world_size)] + gather_b = [torch.zeros(max_size, dtype=torch.int32, device=device) for _ in range(h.world_size)] + else: + gather_t = None + gather_l = None + gather_b = None + dist.gather(tga_pad, gather_t, dst=0) + dist.gather(lpa_pad, gather_l, dst=0) + dist.gather(bla_pad, gather_b, dst=0) + if h.rank == 0: + tga_full = torch.cat([gather_t[r][:sizes_list[r]] for r in range(h.world_size)]).cpu().numpy() + lpa_full = torch.cat([gather_l[r][:sizes_list[r]] for r in range(h.world_size)]).cpu().numpy() + bla_full = torch.cat([gather_b[r][:sizes_list[r]] for r in range(h.world_size)]).cpu().numpy() + if getattr(h, "ppm_dump_inputs", False): + dump_path = os.path.join(h.artifact_dir or ".", f"{h.run_id}.ppm_inputs.npz") + np.savez_compressed( + dump_path, + target_ids=tga_full.astype(np.int64), + logp=lpa_full.astype(np.float64), + byte_lens=bla_full.astype(np.int32), + ) + log(f"ppm_dump_inputs:{dump_path}") + t0 = time.perf_counter() + mixer_bpb = _ppm_mixture_bpb(tga_full, lpa_full, val_data.sp, O=h.ppm_order, H=h.ppm_h, L_=h.ppm_l, T=h.ppm_t, token_byte_lens_np=bla_full) + log(f'ppm_mixer val_bpb:{mixer_bpb:.8f} eval_time:{1000.0*(time.perf_counter()-t0):.0f}ms order={h.ppm_order} H={h.ppm_h} L={h.ppm_l} T={h.ppm_t} N_tokens={lpa_full.size} N_sidecar_bytes={int(bla_full.sum())}') + val_bpb = mixer_bpb + else: + tga_np = tga_local_cat.numpy() + lpa_np = lpa_local_cat.numpy() + bla_np = bla_local_cat.numpy() + if getattr(h, "ppm_dump_inputs", False): + dump_path = os.path.join(h.artifact_dir or ".", f"{h.run_id}.ppm_inputs.npz") + np.savez_compressed( + dump_path, + target_ids=tga_np.astype(np.int64), + logp=lpa_np.astype(np.float64), + byte_lens=bla_np.astype(np.int32), + ) + log(f"ppm_dump_inputs:{dump_path}") + t0 = time.perf_counter() + mixer_bpb = _ppm_mixture_bpb(tga_np, lpa_np, val_data.sp, O=h.ppm_order, H=h.ppm_h, L_=h.ppm_l, T=h.ppm_t, token_byte_lens_np=bla_np) + log(f'ppm_mixer val_bpb:{mixer_bpb:.8f} eval_time:{1000.0*(time.perf_counter()-t0):.0f}ms order={h.ppm_order} H={h.ppm_h} L={h.ppm_l} T={h.ppm_t} N_tokens={lpa_np.size} N_sidecar_bytes={int(bla_np.sum())}') + val_bpb = mixer_bpb + model.train() + return val_loss, val_bpb + + +def eval_val(h, device, val_data, model, forward_logits_fn=None): + seq_len = h.eval_seq_len + local_batch_tokens = h.val_batch_tokens // (h.world_size * h.grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + f"VAL_BATCH_SIZE must provide at least one sequence per rank; got VAL_BATCH_SIZE={h.val_batch_tokens}, WORLD_SIZE={h.world_size}, GRAD_ACCUM_STEPS={h.grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_data.val_tokens.numel() - 1) // seq_len + seq_start = total_seqs * h.rank // h.world_size + seq_end = total_seqs * (h.rank + 1) // h.world_size + + # TODO: Don't truncate this. + seq_end = seq_start + ((seq_end - seq_start) // local_batch_seqs) * local_batch_seqs + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + run_forward_logits = ( + (model.module.forward_logits if hasattr(model, "module") else model.forward_logits) + if forward_logits_fn is None + else forward_logits_fn + ) + model.eval() + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + with torch.no_grad(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_data.val_tokens[raw_start:raw_end].to( + device=device, dtype=torch.int64, non_blocking=True + ) + x = local[:-1] + y = local[1:] + bos_pos = (x == BOS_ID).nonzero(as_tuple=True)[0].tolist() + cu_seqlens, max_seqlen = _build_cu_seqlens( + bos_pos, x.numel(), x.device, h.eval_seq_len, 64 + ) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits = run_forward_logits( + x[None], cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ).detach() + per_token_loss = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y.reshape(-1), + reduction="none", + ) + val_loss_sum += per_token_loss.to(torch.float64).sum() + val_token_count += float(y.numel()) + prev_ids = x + tgt_ids = y + sidecar_slice = val_data.val_bytes[raw_start + 1 : raw_end].to( + device=device, dtype=torch.int32, non_blocking=True + ) + val_byte_count += sidecar_slice.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + model.train() + return _loss_bpb(val_loss_sum, val_token_count, val_byte_count) + + +def jepa_weight_at(h, step, frac): + if not h.jepa_lite_enabled: + return 0.0 + freq = max(1, h.jepa_lite_frequency) + if step % freq != 0 or frac < h.jepa_lite_start_frac: + return 0.0 + weight = h.jepa_lite_weight + if frac > h.jepa_lite_end_frac: + weight *= max( + 0.0, + (1.0 - frac) / max(1e-09, 1.0 - h.jepa_lite_end_frac), + ) + return weight + + +def _find_docs(all_tokens): + bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].numpy() + docs = [] + for i in range(len(bos_positions)): + start = int(bos_positions[i]) + end = ( + int(bos_positions[i + 1]) + if i + 1 < len(bos_positions) + else all_tokens.numel() + ) + if i + 1 < len(bos_positions): + end += 1 + assert end - start >= 2 + docs.append((start, end - start)) + return docs + + +def _build_ttt_global_batches(doc_entries, h, ascending=False): + batch_size = h.ttt_batch_size + global_doc_entries = sorted(doc_entries, key=lambda x: x[1][1]) + global_batches = [ + global_doc_entries[i : i + batch_size] + for i in range(0, len(global_doc_entries), batch_size) + ] + indexed = list(enumerate(global_batches)) + if not ascending: + indexed.sort(key=lambda ib: -max(dl for _, (_, dl) in ib[1])) + return indexed + + +def _init_batch_counter(path): + with open(path, "wb") as f: + f.write((0).to_bytes(4, "little")) + + +def _claim_next_batch(counter_path, queue_len): + try: + with open(counter_path, "r+b") as f: + fcntl.flock(f, fcntl.LOCK_EX) + idx = int.from_bytes(f.read(4), "little") + f.seek(0) + f.write((idx + 1).to_bytes(4, "little")) + f.flush() + except FileNotFoundError: + return queue_len + return idx + + +def _compute_chunk_window(ci, pred_len, num_chunks, chunk_size, eval_seq_len): + chunk_end = pred_len if ci == num_chunks - 1 else (ci + 1) * chunk_size + win_start = max(0, chunk_end - eval_seq_len) + win_len = chunk_end - win_start + chunk_start = ci * chunk_size + chunk_offset = chunk_start - win_start + chunk_len = chunk_end - chunk_start + return win_start, win_len, chunk_offset, chunk_len + + +def _accumulate_bpb( + ptl, + x, + y, + chunk_offsets, + chunk_lens, + pos_idx, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + loss_sum, + byte_sum, + token_count, + y_bytes=None, +): + pos = pos_idx[: x.size(1)].unsqueeze(0) + mask = ( + (chunk_lens.unsqueeze(1) > 0) + & (pos >= chunk_offsets.unsqueeze(1)) + & (pos < (chunk_offsets + chunk_lens).unsqueeze(1)) + ) + mask_f64 = mask.to(torch.float64) + if y_bytes is not None: + tok_bytes = y_bytes.to(torch.float64) + else: + tok_bytes = base_bytes_lut[y].to(torch.float64) + tok_bytes += (has_leading_space_lut[y] & ~is_boundary_token_lut[x]).to( + torch.float64 + ) + loss_sum += (ptl.to(torch.float64) * mask_f64).sum() + byte_sum += (tok_bytes * mask_f64).sum() + token_count += chunk_lens.to(torch.float64).sum() + + +def _loss_bpb_from_sums(loss_sum, token_count, byte_sum): + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_sum.item()) + return val_loss, val_bpb + + +def _add_to_counter(path, delta): + try: + with open(path, "r+b") as f: + fcntl.flock(f, fcntl.LOCK_EX) + cur = int.from_bytes(f.read(8), "little", signed=True) + cur += int(delta) + f.seek(0) + f.write(int(cur).to_bytes(8, "little", signed=True)) + f.flush() + return cur + except FileNotFoundError: + return int(delta) + + +def _init_int64_counter(path): + with open(path, "wb") as f: + f.write((0).to_bytes(8, "little", signed=True)) + + +def _select_ttt_doc_entries(docs, h): + doc_entries = list(enumerate(docs)) + if h.val_doc_fraction < 1.0: + sample_n = max(1, int(round(len(docs) * h.val_doc_fraction))) + if os.environ.get("VAL_DOC_PREFIX_ONLY", "0") == "1": + return doc_entries[:sample_n] + sampled_indices = sorted( + random.Random(h.seed).sample(range(len(docs)), sample_n) + ) + return [(i, docs[i]) for i in sampled_indices] + return doc_entries + + +def train_val_ttt_global_sgd_distributed(h, device, val_data, base_model, val_tokens, batch_seqs=None): + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + base_model.eval() + seq_len = h.eval_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = h.global_ttt_chunk_tokens + batch_seqs = h.global_ttt_batch_seqs if batch_seqs is None else batch_seqs + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + ttt_params = [p for p in base_model.parameters()] + for p in ttt_params: + p.requires_grad_(True) + optimizer = torch.optim.SGD( + ttt_params, lr=h.global_ttt_lr, momentum=h.global_ttt_momentum + ) + t_start = time.perf_counter() + for ci in range(num_chunks): + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + is_last_chunk = ci == num_chunks - 1 + if is_last_chunk or h.global_ttt_epochs <= 0: + continue + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs <= 0: + continue + warmup_chunks = max(0, min(h.global_ttt_warmup_chunks, num_chunks - 1)) + if warmup_chunks > 0 and ci < warmup_chunks: + warmup_denom = max(warmup_chunks - 1, 1) + warmup_t = ci / warmup_denom + lr_now = ( + h.global_ttt_warmup_start_lr + + (h.global_ttt_lr - h.global_ttt_warmup_start_lr) * warmup_t + ) + else: + decay_steps = max(num_chunks - 1 - warmup_chunks, 1) + decay_ci = max(ci - warmup_chunks, 0) + lr_now = h.global_ttt_lr * 0.5 * ( + 1.0 + math.cos(math.pi * decay_ci / decay_steps) + ) + for pg in optimizer.param_groups: + pg["lr"] = lr_now + my_seq_s = chunk_seqs * h.rank // h.world_size + my_seq_e = chunk_seqs * (h.rank + 1) // h.world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ in range(h.global_ttt_epochs): + for bs in range(0, my_chunk_seqs, batch_seqs): + be = min(bs + batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x_flat = local[:-1] + y_flat = local[1:] + optimizer.zero_grad(set_to_none=True) + with torch.enable_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + if h.global_ttt_respect_doc_boundaries: + bos_pos = (x_flat == BOS_ID).nonzero(as_tuple=True)[0].tolist() + cu_seqlens, max_seqlen = _build_cu_seqlens( + bos_pos, x_flat.numel(), x_flat.device, h.eval_seq_len, 64 + ) + loss = base_model( + x_flat[None], + y_flat[None], + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + else: + x = x_flat.reshape(-1, seq_len) + y = y_flat.reshape(-1, seq_len) + loss = base_model(x, y) + loss.backward() + if dist.is_available() and dist.is_initialized(): + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.SUM) + p.grad.mul_(1.0 / h.world_size) + if h.global_ttt_grad_clip > 0: + torch.nn.utils.clip_grad_norm_(ttt_params, h.global_ttt_grad_clip) + optimizer.step() + base_model.eval() + if h.rank == 0: + elapsed = time.perf_counter() - t_start + log( + f"tttg: c{ci+1}/{num_chunks} lr:{lr_now:.6f} t:{elapsed:.1f}s" + ) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + +def _compute_ngram_hints_for_val(h, val_data, log0=print): + if not getattr(h, "ngram_tilt_enabled", False): + return None + from online_ngram_tilt import build_hints_for_targets + + all_tokens = val_data.val_tokens + targets_np_all = all_tokens.cpu().numpy().astype("uint16", copy=False)[1:] + max_targets = int(os.environ.get("NGRAM_HINT_MAX_TARGETS", "0")) + target_count = targets_np_all.shape[0] + if max_targets > 0: + targets_np = targets_np_all[: min(max_targets, target_count)] + else: + targets_np = targets_np_all + t_h0 = time.perf_counter() + hints_pkg = build_hints_for_targets( + target_token_ids_np=targets_np, + tokenizer_path=h.tokenizer_path, + vocab_size=h.vocab_size, + log0=log0, + token_order=h.token_order, + token_threshold=h.token_threshold, + token_boost=h.token_boost, + within_tau=h.within_tau, + within_boost=h.within_boost, + word_order=h.word_order, + word_normalize=h.word_normalize, + word_tau=h.word_tau, + word_boost=h.word_boost, + agree_add_boost=h.agree_add_boost, + ) + hint_global = torch.from_numpy(hints_pkg["hint_ids"].astype("int64")) + gate_global = torch.from_numpy(hints_pkg["gate_mask"]) + boost_global = torch.from_numpy(hints_pkg["boost"].astype("float32")) + if hint_global.numel() < target_count: + padded_hint = torch.zeros(target_count, dtype=torch.int64) + padded_gate = torch.zeros(target_count, dtype=torch.bool) + padded_boost = torch.zeros(target_count, dtype=torch.float32) + padded_hint[: hint_global.numel()] = hint_global + padded_gate[: gate_global.numel()] = gate_global + padded_boost[: boost_global.numel()] = boost_global + hint_global, gate_global, boost_global = padded_hint, padded_gate, padded_boost + log0( + f"ngram_tilt:precompute_done elapsed={time.perf_counter()-t_h0:.2f}s " + f"total_targets={hint_global.numel()} computed_targets={targets_np.shape[0]}" + ) + return hint_global, gate_global, boost_global + + +def eval_val_ttt_phased(h, base_model, device, val_data, forward_ttt_train, precomputed_hints=None): + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + base_model.eval() + for p in base_model.parameters(): + p.requires_grad_(False) + all_tokens = val_data.val_tokens + all_tokens_idx = all_tokens.to(torch.int32) + ngram_hint_global = None + ngram_gate_global = None + ngram_boost_global = None + if precomputed_hints is not None: + ngram_hint_global, ngram_gate_global, ngram_boost_global = precomputed_hints + log( + "ngram_tilt:using_precomputed_hints " + f"total_targets={ngram_hint_global.numel()}" + ) + elif getattr(h, "ngram_tilt_enabled", False): + ngram_hint_global, ngram_gate_global, ngram_boost_global = _compute_ngram_hints_for_val( + h, val_data, log0=log + ) + docs = _find_docs(all_tokens) + doc_entries = _select_ttt_doc_entries(docs, h) + prefix_doc_limit = max(0, min(len(doc_entries), int(h.phased_ttt_prefix_docs))) + num_phases = max(1, int(h.phased_ttt_num_phases)) + phase_boundaries = [] + for pi in range(num_phases): + boundary = prefix_doc_limit * (pi + 1) // num_phases + phase_boundaries.append(boundary) + current_phase = 0 + current_phase_boundary = phase_boundaries[0] + log( + "ttt_phased:" + f" total_docs:{len(doc_entries)} prefix_docs:{prefix_doc_limit} " + f"suffix_docs:{len(doc_entries) - prefix_doc_limit}" + f" num_phases:{num_phases} boundaries:{phase_boundaries}" + ) + chunk_size, eval_seq_len = h.ttt_chunk_size, h.ttt_eval_seq_len + eval_batch_set = None + if h.ttt_eval_batches: + eval_batch_set = set(int(x) for x in h.ttt_eval_batches.split(",") if x.strip()) + use_ascending = eval_batch_set is not None + global_batches_sorted = _build_ttt_global_batches( + doc_entries, h, ascending=use_ascending + ) + queue_len = len(global_batches_sorted) + counter_path = f"/tmp/ttt_counter_{h.run_id}" + prefix_counter_path = f"/tmp/ttt_prefix_counter_{h.run_id}" + pause_flag_path = f"/tmp/ttt_pause_flag_{h.run_id}" + if h.rank == 0: + _init_batch_counter(counter_path) + _init_int64_counter(prefix_counter_path) + try: + os.remove(pause_flag_path) + except FileNotFoundError: + pass + if dist.is_available() and dist.is_initialized(): + path_list = [counter_path, prefix_counter_path, pause_flag_path] + dist.broadcast_object_list(path_list, src=0) + counter_path, prefix_counter_path, pause_flag_path = path_list + dist.barrier() + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + byte_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + t_start = time.perf_counter() + reusable_lora = BatchedTTTLoRA( + h.ttt_batch_size, base_model, h.ttt_lora_rank, + q_lora=h.ttt_q_lora, k_lora=h.ttt_k_lora, v_lora=h.ttt_v_lora, + mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + + def _build_opt(lora): + local_lr = h.ttt_lora_lr * h.ttt_local_lr_mult + if h.ttt_optimizer == "sgd": + return torch.optim.SGD( + lora.parameters(), lr=local_lr, + momentum=h.ttt_beta1, weight_decay=h.ttt_weight_decay, + ) + return torch.optim.AdamW( + lora.parameters(), lr=local_lr, + betas=(h.ttt_beta1, h.ttt_beta2), + eps=1e-10, weight_decay=h.ttt_weight_decay, fused=True, + ) + + reusable_opt = _build_opt(reusable_lora) + local_scored_docs = [] + global_ttt_done = prefix_doc_limit == 0 + try: + while True: + queue_idx = _claim_next_batch(counter_path, queue_len) + if queue_idx >= queue_len: + break + orig_batch_idx, batch_entries = global_batches_sorted[queue_idx] + batch = [doc for _, doc in batch_entries] + bsz = len(batch) + prev_loss = loss_sum.item() + prev_bytes = byte_sum.item() + prev_tokens = token_count.item() + if bsz == reusable_lora.bsz: + reusable_lora.reset() + for s in reusable_opt.state.values(): + for k, v in s.items(): + if isinstance(v, torch.Tensor): + v.zero_() + elif k == "step": + s[k] = 0 + cur_lora = reusable_lora + cur_opt = reusable_opt + else: + cur_lora = BatchedTTTLoRA( + bsz, base_model, h.ttt_lora_rank, + q_lora=h.ttt_q_lora, k_lora=h.ttt_k_lora, v_lora=h.ttt_v_lora, + mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + cur_opt = _build_opt(cur_lora) + pred_lens = [doc_len - 1 for _, doc_len in batch] + num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens] + max_nc = max(num_chunks) + num_chunks_t = torch.tensor(num_chunks, dtype=torch.int64, device=device) + for ci in range(max_nc): + active = [ci < nc for nc in num_chunks] + needs_train = any(ci < nc - 1 for nc in num_chunks) + tok_starts = torch.zeros(bsz, dtype=torch.int64) + tok_wls = torch.zeros(bsz, dtype=torch.int64) + chunk_offsets_cpu = torch.zeros(bsz, dtype=torch.int64) + chunk_lens_cpu = torch.zeros(bsz, dtype=torch.int64) + for b in range(bsz): + if not active[b]: + continue + doc_start, doc_len = batch[b] + win_start, win_len, chunk_offset, chunk_len = _compute_chunk_window( + ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len + ) + tok_starts[b] = doc_start + win_start + tok_wls[b] = win_len + chunk_offsets_cpu[b] = chunk_offset + chunk_lens_cpu[b] = chunk_len + _, context_size, chunk_offset, _ = _compute_chunk_window( + ci, (ci + 1) * chunk_size, ci + 1, chunk_size, eval_seq_len + ) + col_idx = torch.arange(context_size + 1) + idx = tok_starts.unsqueeze(1) + col_idx.unsqueeze(0) + idx.clamp_(max=all_tokens.numel() - 1) + gathered_gpu = all_tokens_idx[idx].to( + device=device, dtype=torch.int64, non_blocking=True + ) + valid = (col_idx[:context_size].unsqueeze(0) < tok_wls.unsqueeze(1)).to( + device, non_blocking=True + ) + chunk_offsets = chunk_offsets_cpu.to(device, non_blocking=True) + chunk_lens = chunk_lens_cpu.to(device, non_blocking=True) + x = torch.where(valid, gathered_gpu[:, :context_size], 0) + y = torch.where(valid, gathered_gpu[:, 1 : context_size + 1], 0) + ctx_pos = torch.arange(context_size, device=device, dtype=torch.int64) + hint_ids_gpu = None + gate_mask_gpu = None + boost_gpu = None + if ngram_hint_global is not None: + hint_idx_cpu = ( + tok_starts.unsqueeze(1) + col_idx[:context_size].unsqueeze(0) + ).clamp_(min=0, max=ngram_hint_global.numel() - 1) + hint_ids_gpu = ngram_hint_global[hint_idx_cpu].to( + device=device, dtype=torch.int64, non_blocking=True + ) + gate_mask_gpu = ngram_gate_global[hint_idx_cpu].to( + device=device, non_blocking=True + ) + boost_gpu = ngram_boost_global[hint_idx_cpu].to( + device=device, dtype=torch.float32, non_blocking=True + ) + hint_ids_gpu = torch.where(valid, hint_ids_gpu, torch.zeros_like(hint_ids_gpu)) + gate_mask_gpu = gate_mask_gpu & valid + log_q_hint = None + if hint_ids_gpu is not None: + with torch.no_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_tok_loss, log_q_hint = forward_ttt_train( + x, y, lora=cur_lora, hint_ids=hint_ids_gpu + ) + else: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_tok_loss = forward_ttt_train(x, y, lora=cur_lora) + # CaseOps sidecar-driven byte budget. Mirror the index pattern + # used to build y from all_tokens: y[b, j] corresponds to the + # token at global position tok_starts[b] + 1 + j (when valid). + y_bytes_arg = None + if val_data.caseops_enabled and val_data.val_bytes is not None: + y_idx = ( + tok_starts.unsqueeze(1) + + 1 + + col_idx[:context_size].unsqueeze(0) + ) + y_idx = y_idx.clamp_(max=val_data.val_bytes.numel() - 1) + y_bytes_arg = val_data.val_bytes[y_idx].to( + device=device, dtype=torch.int32, non_blocking=True + ) + # Mirror the `valid` masking used for y so out-of-range tokens + # contribute zero bytes (matches y=0 substitution above). + y_bytes_arg = torch.where( + valid, y_bytes_arg, torch.zeros_like(y_bytes_arg) + ) + if hint_ids_gpu is not None and log_q_hint is not None: + from online_ngram_tilt import apply_tilt_to_ptl_torch_fast + + scored_loss = apply_tilt_to_ptl_torch_fast( + ptl=per_tok_loss, + log_q_hint=log_q_hint, + target_ids=y, + hint_ids=hint_ids_gpu, + gate_mask=gate_mask_gpu, + boost=boost_gpu, + ) + else: + scored_loss = per_tok_loss + with torch.no_grad(): + _accumulate_bpb( + scored_loss, + x, + y, + chunk_offsets, + chunk_lens, + ctx_pos, + val_data.base_bytes_lut, + val_data.has_leading_space_lut, + val_data.is_boundary_token_lut, + loss_sum, + byte_sum, + token_count, + y_bytes=y_bytes_arg, + ) + if scored_loss is not per_tok_loss: + del scored_loss + if needs_train: + activate_chunk_mask = (num_chunks_t - 1 > ci).float() + train_x, train_y = x, y + train_chunk_offset = chunk_offset + train_window = int(getattr(h, "ttt_train_window_tokens", 0)) + if train_window > 0 and context_size > max(train_window, chunk_size): + train_window = max(train_window, chunk_size) + train_end = min(context_size, chunk_offset + chunk_size) + train_start = max(0, train_end - train_window) + train_x = x[:, train_start:train_end].contiguous() + train_y = y[:, train_start:train_end].contiguous() + train_chunk_offset = chunk_offset - train_start + for gi in range(h.ttt_grad_steps): + if hint_ids_gpu is not None or gi > 0: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + train_per_tok_loss = forward_ttt_train( + train_x, train_y, lora=cur_lora + ) + else: + train_per_tok_loss = per_tok_loss + per_doc = train_per_tok_loss[ + :, train_chunk_offset : train_chunk_offset + chunk_size + ].mean(dim=-1) + cur_opt.zero_grad(set_to_none=True) + (per_doc * activate_chunk_mask).sum().backward() + cur_opt.step() + if train_per_tok_loss is not per_tok_loss: + del train_per_tok_loss + del per_tok_loss + batch_num = orig_batch_idx + 1 + doc_lens = [dl for _, dl in batch] + should_report = batch_num in eval_batch_set if eval_batch_set is not None else True + if should_report: + cur_tokens = token_count.item() + cur_loss_val = loss_sum.item() + cur_bytes_val = byte_sum.item() + dt = cur_tokens - prev_tokens + db = cur_bytes_val - prev_bytes + if dt > 0 and db > 0: + b_loss = (cur_loss_val - prev_loss) / dt + b_bpb = b_loss / math.log(2.0) * (dt / db) + else: + b_loss = b_bpb = 0.0 + r_loss = cur_loss_val / max(cur_tokens, 1) + r_bpb = r_loss / math.log(2.0) * (cur_tokens / max(cur_bytes_val, 1)) + elapsed = time.perf_counter() - t_start + log( + f"ttp: b{batch_num}/{queue_len} bl:{b_loss:.4f} bb:{b_bpb:.4f} " + f"rl:{r_loss:.4f} rb:{r_bpb:.4f} dl:{min(doc_lens)}-{max(doc_lens)} " + f"gd:{int(global_ttt_done)}" + ) + if not global_ttt_done: + local_scored_docs.extend( + (orig_batch_idx, pos, doc_start, doc_len) + for pos, (doc_start, doc_len) in enumerate(batch) + ) + prefix_done = _add_to_counter(prefix_counter_path, len(batch_entries)) + if prefix_done >= current_phase_boundary: + try: + with open(pause_flag_path, "x"): + pass + except FileExistsError: + pass + should_pause = os.path.exists(pause_flag_path) + if should_pause: + if dist.is_available() and dist.is_initialized(): + dist.barrier() + gathered_scored_docs = [None] * h.world_size + if dist.is_available() and dist.is_initialized(): + dist.all_gather_object(gathered_scored_docs, local_scored_docs) + else: + gathered_scored_docs = [local_scored_docs] + scored_docs_for_global = [] + for rank_docs in gathered_scored_docs: + if rank_docs: + scored_docs_for_global.extend(rank_docs) + scored_docs_for_global.sort(key=lambda x: (x[0], x[1])) + scored_docs_for_global = scored_docs_for_global[:current_phase_boundary] + scored_token_chunks = [ + val_data.val_tokens[doc_start : doc_start + doc_len] + for _, _, doc_start, doc_len in scored_docs_for_global + ] + if scored_token_chunks: + global_ttt_tokens = torch.cat(scored_token_chunks) + else: + global_ttt_tokens = val_data.val_tokens[:0] + if h.rank == 0: + prefix_done = 0 + try: + with open(prefix_counter_path, "rb") as f: + prefix_done = int.from_bytes( + f.read(8), "little", signed=True + ) + except FileNotFoundError: + pass + log( + f"ttpp: phase:{current_phase + 1}/{num_phases} pd:{prefix_done} " + f"gd:{len(scored_docs_for_global)} " + f"t:{time.perf_counter() - t_start:.1f}s" + ) + train_val_ttt_global_sgd_distributed( + h, device, val_data, base_model, global_ttt_tokens + ) + for p in base_model.parameters(): + p.requires_grad_(False) + reusable_lora = BatchedTTTLoRA( + h.ttt_batch_size, base_model, h.ttt_lora_rank, + q_lora=h.ttt_q_lora, k_lora=h.ttt_k_lora, v_lora=h.ttt_v_lora, + mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + reusable_opt = _build_opt(reusable_lora) + current_phase += 1 + if current_phase >= num_phases: + global_ttt_done = True + else: + current_phase_boundary = phase_boundaries[current_phase] + if h.rank == 0: + try: + os.remove(pause_flag_path) + except FileNotFoundError: + pass + if dist.is_available() and dist.is_initialized(): + dist.barrier() + if h.rank == 0: + log(f"ttpr: phase:{current_phase}/{num_phases} t:{time.perf_counter() - t_start:.1f}s") + del cur_lora, cur_opt + finally: + pass + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.train() + return _loss_bpb_from_sums(loss_sum, token_count, byte_sum) + + +def timed_eval(label, fn, *args, **kwargs): + torch.cuda.synchronize() + t0 = time.perf_counter() + val_loss, val_bpb = fn(*args, **kwargs) + torch.cuda.synchronize() + elapsed_ms = 1e3 * (time.perf_counter() - t0) + log( + f"{label} val_loss:{val_loss:.8f} val_bpb:{val_bpb:.8f} eval_time:{elapsed_ms:.0f}ms" + ) + return val_loss, val_bpb + + +def train_model(h, device, val_data): + base_model = GPT(h).to(device).bfloat16() + restore_fp32_params(base_model) + compile_enabled = os.environ.get("DISABLE_COMPILE", "0") != "1" + if compile_enabled: + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + compiled_forward_logits = torch.compile( + base_model.forward_logits, dynamic=False, fullgraph=True + ) + else: + log("compile:disabled_by_env") + compiled_model = base_model + compiled_forward_logits = base_model.forward_logits + model = compiled_model + log(f"model_params:{sum(p.numel()for p in base_model.parameters())}") + if h.jepa_lite_enabled: + log( + "jepa_lite_config: " + f"weight={h.jepa_lite_weight} hidden={h.jepa_lite_hidden} " + f"horizon={h.jepa_lite_horizon} stride={h.jepa_lite_stride} " + f"frequency={h.jepa_lite_frequency} start_frac={h.jepa_lite_start_frac} " + f"end_frac={h.jepa_lite_end_frac} lr={h.jepa_lite_lr}" + ) + if h.bijepa_enabled: + log( + "bijepa_config: " + f"backward_weight={h.bijepa_backward_weight} " + "dual_predictor=true training_only=true" + ) + bijepax_module = None + bijepax_opt = None + forward_with_hidden = None + if h.bijepax_enabled: + bijepax_module = MultiDirectionalBiJEPAX( + d_model=h.model_dim, + fwd_hops=h.bijepax_fwd_hops, + bwd_hops=h.bijepax_bwd_hops, + cycle=h.bijepax_cycle, + head_dim=h.bijepax_head_dim, + ).to(device).bfloat16() + bijepax_opt = torch.optim.Adam( + bijepax_module.parameters(), lr=h.bijepax_lr, betas=(0.9, 0.99) + ) + forward_with_hidden = ( + torch.compile(base_model.forward_with_hidden, dynamic=False, fullgraph=True) + if compile_enabled + else base_model.forward_with_hidden + ) + log( + "bijepax_config: " + f"weight={h.bijepax_weight} start={h.bijepax_start_frac} " + f"end={h.bijepax_end_frac} fwd={h.bijepax_fwd_hops} " + f"bwd={h.bijepax_bwd_hops} cycle={h.bijepax_cycle} " + f"head_dim={h.bijepax_head_dim} stride={h.bijepax_stride} " + f"lr={h.bijepax_lr} training_only=true" + ) + optimizers = Optimizers(h, base_model) + train_loader = DocumentPackingLoader(h, device) + max_wallclock_ms = ( + 1e3 * h.max_wallclock_seconds if h.max_wallclock_seconds > 0 else None + ) + if max_wallclock_ms is not None: + max_wallclock_ms -= h.gptq_reserve_seconds * 1e3 + log( + f"gptq:reserving {h.gptq_reserve_seconds:.0f}s, effective={max_wallclock_ms:.0f}ms" + ) + + def training_frac(step, elapsed_ms): + if max_wallclock_ms is None: + return step / max(h.iterations, 1) + return elapsed_ms / max(max_wallclock_ms, 1e-09) + + def lr_mul(frac): + if h.warmdown_frac <= 0: + return 1.0 + if frac >= 1.0 - h.warmdown_frac: + return max((1.0 - frac) / h.warmdown_frac, h.min_lr) + return 1.0 + + _clip_params = [p for p in base_model.parameters() if p.requires_grad] + def step_fn(step, lr_scale, frac_now=0.0): + if getattr(base_model, "jepa_predictor", None) is not None: + base_model.jepa_weight_current.fill_(jepa_weight_at(h, step, frac_now)) + bijepax_weight = ( + bijepax_weight_at( + frac_now, + h.bijepax_start_frac, + h.bijepax_end_frac, + h.bijepax_weight, + ) + if bijepax_module is not None + else 0.0 + ) + train_loss = torch.zeros((), device=device) + for micro_step in range(h.grad_accum_steps): + x, y, cu_seqlens, _max_seqlen = train_loader.next_batch( + h.train_batch_tokens, h.grad_accum_steps + ) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + if bijepax_module is not None and bijepax_weight > 0.0: + ce_loss, hidden = forward_with_hidden( + x, y, cu_seqlens=cu_seqlens, max_seqlen=h.train_seq_len + ) + loss = ce_loss + bijepax_module( + hidden, bijepax_weight, h.bijepax_stride + ) + else: + loss = model( + x, y, cu_seqlens=cu_seqlens, max_seqlen=h.train_seq_len + ) + train_loss += loss.detach() + (loss / h.grad_accum_steps).backward() + train_loss /= h.grad_accum_steps + if step <= h.muon_momentum_warmup_steps: + + frac = ( + + min(step / h.muon_momentum_warmup_steps, 1.0) + + if h.muon_momentum_warmup_steps > 0 + + else 1.0 + + ) + + muon_momentum = ( + + 1 - frac + + ) * h.muon_momentum_warmup_start + frac * h.muon_momentum + + for group in optimizers.optimizer_muon.param_groups: + + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * lr_scale + if h.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(_clip_params, h.grad_clip_norm) + optimizers.step(distributed=h.distributed) + if bijepax_opt is not None and bijepax_weight > 0.0: + bijepax_opt.step() + bijepax_opt.zero_grad(set_to_none=True) + return train_loss + + if h.warmup_steps > 0: + initial_model_state = { + name: tensor.detach().cpu().clone() + for (name, tensor) in base_model.state_dict().items() + } + initial_optimizer_states = [ + copy.deepcopy(opt.state_dict()) for opt in optimizers + ] + model.train() + num_tokens_local = h.train_batch_tokens // h.world_size + for blk in base_model.blocks: + blk.attn.rotary(num_tokens_local, device, torch.bfloat16) + cu_bucket_size = train_loader.cu_bucket_size + warmup_cu_buckets = tuple(cu_bucket_size * i for i in range(1, 5)) + warmup_cu_iters = 3 + x, y, cu_seqlens, _ = train_loader.next_batch( + h.train_batch_tokens, h.grad_accum_steps + ) + log(f"warmup_cu_buckets:{','.join(str(b) for b in warmup_cu_buckets)} iters_each:{warmup_cu_iters}") + def _run_cu_bucket_warmup(): + for bucket_len in warmup_cu_buckets: + boundaries = list(range(0, x.size(1), max(h.train_seq_len, 1))) + if boundaries[-1] != x.size(1): + boundaries.append(x.size(1)) + cu = torch.full((bucket_len,), x.size(1), dtype=torch.int32, device=device) + cu[: len(boundaries)] = torch.tensor(boundaries, dtype=torch.int32, device=device) + for _ in range(warmup_cu_iters): + optimizers.zero_grad_all() + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + wloss = model(x, y, cu_seqlens=cu, max_seqlen=h.train_seq_len) + (wloss / h.grad_accum_steps).backward() + optimizers.zero_grad_all() + _run_cu_bucket_warmup() + if h.num_loops > 0: + base_model.looping_active = True + _run_cu_bucket_warmup() + base_model.looping_active = False + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step, 1.0) + if ( + warmup_step <= 5 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == h.warmup_steps + ): + log(f"warmup_step: {warmup_step+1}/{h.warmup_steps}") + if h.num_loops > 0: + base_model.looping_active = True + log( + f"loop_warmup:enabled encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}" + ) + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step, 1.0) + if ( + warmup_step <= 5 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == h.warmup_steps + ): + log(f"loop_warmup_step: {warmup_step+1}/{h.warmup_steps}") + base_model.looping_active = False + base_model.load_state_dict(initial_model_state, strict=True) + for (opt, state) in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + optimizers.zero_grad_all() + train_loader = DocumentPackingLoader(h, device) + _live_state = base_model.state_dict(keep_vars=True) + ema_state = { + name: t.detach().float().clone() + for (name, t) in _live_state.items() + if not name.startswith("jepa_predictor.") + and not name.startswith("jepa_bwd_predictor.") + and name != "jepa_weight_current" + } + _ema_pairs = [ + (ema_state[name], t) + for (name, t) in _live_state.items() + if name in ema_state + ] + ema_decay = h.ema_decay + training_time_ms = 0.0 + forced_stop_step = int(os.environ.get("FORCE_STOP_STEP", "0")) + stop_after_step = forced_stop_step if forced_stop_step > 0 else None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = ( + step == h.iterations + or stop_after_step is not None + and step >= stop_after_step + ) + should_validate = ( + last_step or h.val_loss_every > 0 and step % h.val_loss_every == 0 + ) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1e3 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + h, device, val_data, model, compiled_forward_logits + ) + log( + f"{step}/{h.iterations} val_loss: {val_loss:.4f} val_bpb: {val_bpb:.4f}" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < h.iterations: + log( + f"stopping_early: wallclock_cap train_time: {training_time_ms:.0f}ms step: {step}/{h.iterations}" + ) + break + elapsed_ms = training_time_ms + 1e3 * (time.perf_counter() - t0) + frac = training_frac(step, elapsed_ms) + scale = lr_mul(frac) + if ( + h.num_loops > 0 + and not base_model.looping_active + and frac >= h.enable_looping_at + ): + base_model.looping_active = True + log( + f"layer_loop:enabled step:{step} frac:{frac:.3f} encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}" + ) + train_loss = step_fn(step, scale, frac) + with torch.no_grad(): + for ema_t, t in _ema_pairs: + ema_t.mul_(ema_decay).add_(t.detach(), alpha=1.0 - ema_decay) + step += 1 + approx_training_time_ms = training_time_ms + 1e3 * (time.perf_counter() - t0) + should_log_train = h.train_log_every > 0 and ( + step <= 5 or step % h.train_log_every == 0 or stop_after_step is not None + ) + if should_log_train: + tok_per_sec = step * h.train_batch_tokens / (approx_training_time_ms / 1e3) + log( + f"{step}/{h.iterations} train_loss: {train_loss.item():.4f} train_time: {approx_training_time_ms/60000:.1f}m tok/s: {tok_per_sec:.0f}" + ) + reached_cap = ( + forced_stop_step <= 0 + and max_wallclock_ms is not None + and approx_training_time_ms >= max_wallclock_ms + ) + if h.distributed and forced_stop_step <= 0 and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log( + f"peak memory allocated: {torch.cuda.max_memory_allocated()//1024//1024} MiB reserved: {torch.cuda.max_memory_reserved()//1024//1024} MiB" + ) + log("ema:applying EMA weights") + if getattr(base_model, "jepa_predictor", None) is not None: + base_model.jepa_predictor = None + log("jepa_lite_predictor_removed_before_serialize: true") + if getattr(base_model, "jepa_bwd_predictor", None) is not None: + base_model.jepa_bwd_predictor = None + log("bijepa_backward_predictor_removed_before_serialize: true") + current_state = base_model.state_dict() + avg_state = { + name: t.to(dtype=current_state[name].dtype) for (name, t) in ema_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + return base_model, compiled_model, compiled_forward_logits + + +def train_and_eval(h, device): + global BOS_ID + random.seed(h.seed) + np.random.seed(h.seed) + torch.manual_seed(h.seed) + torch.cuda.manual_seed_all(h.seed) + if h.artifact_dir and h.is_main_process: + os.makedirs(h.artifact_dir, exist_ok=True) + val_data = ValidationData(h, device) + log( + f"train_shards: {len(list(Path(h.datasets_dir).resolve().glob('fineweb_train_*.bin')))}" + ) + log(f"val_tokens: {val_data.val_tokens.numel()-1}") + # TTT_EVAL_ONLY: skip training + GPTQ, jump straight to TTT eval on a + # pre-existing quantized artifact. Used to test TTT-only improvements + # (e.g., PR-1767's alpha/warm-start/WD) without retraining. + ttt_eval_only = os.environ.get("TTT_EVAL_ONLY", "0") == "1" + quantize_only = os.environ.get("QUANTIZE_ONLY", "0") == "1" + if ttt_eval_only: + log("TTT_EVAL_ONLY=1 — skipping training + GPTQ, loading saved artifact for TTT eval") + log(f"ttt_lora_alpha: {BatchedLinearLoRA._ALPHA}") + log(f"ttt_warm_start_a: {BatchedLinearLoRA._WARM_START_A}") + log(f"ttt_weight_decay: {h.ttt_weight_decay}") + if h.jepa_lite_enabled: + h.jepa_lite_enabled = False + log("jepa_lite_disabled_for_eval_only: true") + elif quantize_only: + log("QUANTIZE_ONLY=1 — skipping training, loading saved full-precision checkpoint") + log(f"quantize_only checkpoint: {h.model_path}") + if h.jepa_lite_enabled: + h.jepa_lite_enabled = False + log("jepa_lite_disabled_for_quantize_only: true") + if BOS_ID is None: + BOS_ID = 1 + base_model = GPT(h).to(device).bfloat16() + state = torch.load(h.model_path, map_location="cpu") + base_model.load_state_dict(state, strict=True) + del state + serialize(h, base_model, Path(__file__).read_text(encoding="utf-8")) + if h.distributed: + dist.barrier() + else: + base_model, compiled_model, compiled_forward_logits = train_model( + h, device, val_data + ) + if h.jepa_lite_enabled: + h.jepa_lite_enabled = False + log("jepa_lite_disabled_after_training: true") + torch._dynamo.reset() + timed_eval( + "diagnostic pre-quantization post-ema", + eval_val, + h, + device, + val_data, + compiled_model, + compiled_forward_logits, + ) + if os.environ.get("PREQUANT_ONLY", "0") == "1": + log("PREQUANT_ONLY=1 — skipping serialize/GPTQ/post-quant eval/TTT") + return + serialize(h, base_model, Path(__file__).read_text(encoding="utf-8")) + if h.distributed: + dist.barrier() + eval_model = deserialize(h, device) + if h.num_loops > 0: + eval_model.looping_active = True + if not ttt_eval_only: + if os.environ.get("DISABLE_COMPILE", "0") == "1": + log("eval_compile:disabled_by_env") + compiled_model = eval_model + compiled_forward_logits = eval_model.forward_logits + else: + compiled_model = torch.compile(eval_model, dynamic=False, fullgraph=True) + compiled_forward_logits = torch.compile( + eval_model.forward_logits, dynamic=False, fullgraph=True + ) + timed_eval( + "diagnostic quantized", + eval_val, + h, + device, + val_data, + compiled_model, + compiled_forward_logits, + ) + if h.ttt_enabled or not h.ppm_mixer_enabled: + del eval_model + if h.ttt_enabled: + if not ttt_eval_only: + del compiled_model + if ttt_eval_only: + del eval_model + torch._dynamo.reset() + torch.cuda.empty_cache() + ttt_model = deserialize(h, device) + if h.num_loops > 0: + ttt_model.looping_active = True + for p in ttt_model.parameters(): + p.requires_grad_(False) + + if h.rope_yarn: + _yarn_seqlen = h.train_batch_tokens // h.grad_accum_steps + for block in ttt_model.blocks: + block.attn.rotary(_yarn_seqlen, device, torch.bfloat16) + else: + for block in ttt_model.blocks: + block.attn.rotary._cos_cached = None + block.attn.rotary._sin_cached = None + block.attn.rotary._seq_len_cached = 0 + block.attn.rotary(h.ttt_eval_seq_len, device, torch.bfloat16) + + def _fwd_ttt_inner(input_ids, target_ids, lora): + return ttt_model.forward_ttt(input_ids, target_ids, lora=lora) + + def _fwd_ttt_hint_inner(input_ids, target_ids, lora, hint_ids): + return ttt_model.forward_ttt( + input_ids, target_ids, lora=lora, hint_ids=hint_ids + ) + + _fwd_ttt_compiled_inner = None + _fwd_ttt_hint_compiled_inner = None + + def _fwd_ttt(input_ids, target_ids, lora, hint_ids=None): + nonlocal _fwd_ttt_compiled_inner, _fwd_ttt_hint_compiled_inner + if os.environ.get("DISABLE_COMPILE", "0") == "1": + if hint_ids is not None: + return _fwd_ttt_hint_inner( + input_ids, target_ids, lora=lora, hint_ids=hint_ids + ) + return _fwd_ttt_inner(input_ids, target_ids, lora=lora) + if hint_ids is not None: + if _fwd_ttt_hint_compiled_inner is None: + _fwd_ttt_hint_compiled_inner = torch.compile( + _fwd_ttt_hint_inner, dynamic=True + ) + return _fwd_ttt_hint_compiled_inner( + input_ids, target_ids, lora=lora, hint_ids=hint_ids + ) + if _fwd_ttt_compiled_inner is None: + _fwd_ttt_compiled_inner = torch.compile(_fwd_ttt_inner, dynamic=True) + return _fwd_ttt_compiled_inner(input_ids, target_ids, lora=lora) + + fwd_ttt_compiled = _fwd_ttt + log(f"ttt_lora:warming up compile (random tokens, no val data)") + if BOS_ID is None: + BOS_ID = 1 + t_warmup = time.perf_counter() + warmup_bszes = [h.ttt_batch_size] + for bsz in warmup_bszes: + wl = BatchedTTTLoRA( + bsz, ttt_model, h.ttt_lora_rank, + q_lora=h.ttt_q_lora, k_lora=h.ttt_k_lora, v_lora=h.ttt_v_lora, + mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + wo = torch.optim.AdamW( + wl.parameters(), + lr=h.ttt_lora_lr * h.ttt_local_lr_mult, + betas=(h.ttt_beta1, h.ttt_beta2), + eps=1e-10, + weight_decay=h.ttt_weight_decay, + fused=True, + ) + train_warmup_lens = [h.ttt_chunk_size] + train_window = int(getattr(h, "ttt_train_window_tokens", 0)) + if train_window > h.ttt_chunk_size: + train_warmup_lens.append(train_window) + for ctx_len in train_warmup_lens: + xw = torch.randint(0, h.vocab_size, (bsz, ctx_len), device=device, dtype=torch.int64) + yw = torch.randint(0, h.vocab_size, (bsz, ctx_len), device=device, dtype=torch.int64) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = fwd_ttt_compiled(xw, yw, lora=wl) + ptl[:, : min(h.ttt_chunk_size, ctx_len)].mean(dim=-1).sum().backward() + wo.step() + wo.zero_grad(set_to_none=True) + if h.ngram_tilt_enabled: + ctx_len = h.ttt_eval_seq_len + xw = torch.randint(0, h.vocab_size, (bsz, ctx_len), device=device, dtype=torch.int64) + yw = torch.randint(0, h.vocab_size, (bsz, ctx_len), device=device, dtype=torch.int64) + hintw = torch.randint( + 0, h.vocab_size, (bsz, ctx_len), device=device, dtype=torch.int64 + ) + with torch.no_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + fwd_ttt_compiled(xw, yw, lora=wl, hint_ids=hintw) + del wl, wo + torch.cuda.empty_cache() + compile_elapsed = time.perf_counter() - t_warmup + log(f"ttt_lora:compile warmup done ({compile_elapsed:.1f}s)") + precomputed_hints = None + if h.ngram_tilt_enabled and h.ngram_hint_precompute_outside: + log("ngram_tilt:precomputing hints before TTT eval timer") + precomputed_hints = _compute_ngram_hints_for_val(h, val_data, log0=log) + log("\nbeginning TTT eval timer") + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_val_loss, ttt_val_bpb = eval_val_ttt_phased( + h, + ttt_model, + device, + val_data, + forward_ttt_train=fwd_ttt_compiled, + precomputed_hints=precomputed_hints, + ) + torch.cuda.synchronize() + ttt_eval_elapsed = time.perf_counter() - t_ttt + log( + "quantized_ttt_phased " + f"val_loss:{ttt_val_loss:.8f} val_bpb:{ttt_val_bpb:.8f} " + f"eval_time:{1e3*ttt_eval_elapsed:.0f}ms" + ) + log(f"total_eval_time:{ttt_eval_elapsed:.1f}s") + if h.ppm_mixer_enabled: + import sys as _sys + log("beginning PPM sliding eval") + _sys.stdout.flush() + torch.cuda.synchronize() + if dist.is_available() and dist.is_initialized(): + dist.barrier() + t_ppm = time.perf_counter() + try: + ppm_val_loss, ppm_val_bpb = eval_val_ppm_sliding( + h, device, val_data, ttt_model, batch_seqs=16 + ) + torch.cuda.synchronize() + ppm_elapsed = time.perf_counter() - t_ppm + log( + f"ppm_sliding val_loss:{ppm_val_loss:.8f} val_bpb:{ppm_val_bpb:.8f} " + f"eval_time:{1e3*ppm_elapsed:.0f}ms" + ) + except Exception as _e: + log(f"PPM eval error: {_e}") + import traceback as _tb + log(_tb.format_exc()) + _sys.stdout.flush() + del ttt_model + elif h.ppm_mixer_enabled: + import sys as _sys + log("beginning PPM sliding eval") + _sys.stdout.flush() + torch.cuda.synchronize() + if dist.is_available() and dist.is_initialized(): + dist.barrier() + t_ppm = time.perf_counter() + try: + ppm_val_loss, ppm_val_bpb = eval_val_ppm_sliding( + h, device, val_data, eval_model, batch_seqs=16 + ) + torch.cuda.synchronize() + ppm_elapsed = time.perf_counter() - t_ppm + log( + f"ppm_sliding val_loss:{ppm_val_loss:.8f} val_bpb:{ppm_val_bpb:.8f} " + f"eval_time:{1e3*ppm_elapsed:.0f}ms" + ) + except Exception as _e: + log(f"PPM eval error: {_e}") + import traceback as _tb + log(_tb.format_exc()) + _sys.stdout.flush() + del eval_model + + +def main(): + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError( + f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral" + ) + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.set_float32_matmul_precision("high") + from torch.backends.cuda import ( + enable_cudnn_sdp, + enable_flash_sdp, + enable_math_sdp, + enable_mem_efficient_sdp, + ) + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + torch._dynamo.config.optimize_ddp = False + torch._dynamo.config.cache_size_limit = 64 + h = Hyperparameters() + set_logging_hparams(h) + if h.is_main_process: + os.makedirs(h.artifact_dir if h.artifact_dir else "logs", exist_ok=True) + log(100 * "=", console=False) + log("Hyperparameters:", console=True) + for (k, v) in sorted(vars(type(h)).items()): + if not k.startswith("_"): + log(f" {k}: {v}", console=True) + log("=" * 100, console=False) + log("Source code:", console=False) + log("=" * 100, console=False) + with open(__file__, "r", encoding="utf-8") as _src: + log(_src.read(), console=False) + log("=" * 100, console=False) + log(f"Running Python {sys.version}", console=False) + log(f"Running PyTorch {torch.__version__}", console=False) + log("=" * 100, console=False) + train_and_eval(h, device) + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +==================================================================================================== +train_shards: 80 +val_tokens: 47851520 +compile:disabled_by_env +model_params:35945673 +bijepax_config: weight=0.01 start=0.35 end=0.8 fwd=(4,) bwd=(4,) cycle=False head_dim=32 stride=64 lr=0.001 training_only=true +gptq:reserving 0s, effective=599500ms +warmup_cu_buckets:64,128,192,256 iters_each:3 +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup:enabled encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 +0/20000 val_loss: 9.0149 val_bpb: 4.1192 +1/20000 train_loss: 9.0152 train_time: 0.0m tok/s: 4307036 +2/20000 train_loss: 12.8530 train_time: 0.0m tok/s: 3767064 +3/20000 train_loss: 10.1713 train_time: 0.0m tok/s: 3626995 +4/20000 train_loss: 8.6216 train_time: 0.0m tok/s: 3559304 +5/20000 train_loss: 7.9225 train_time: 0.0m tok/s: 3520003 +250/20000 train_loss: 2.9185 train_time: 1.0m tok/s: 3365671 +500/20000 train_loss: 2.5589 train_time: 1.9m tok/s: 3370654 +750/20000 train_loss: 2.7064 train_time: 2.9m tok/s: 3370035 +layer_loop:enabled step:899 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +1000/20000 train_loss: 2.7514 train_time: 4.1m tok/s: 3206994 +1250/20000 train_loss: 2.5502 train_time: 5.5m tok/s: 2953990 +1500/20000 train_loss: 2.5201 train_time: 7.0m tok/s: 2806608 +1750/20000 train_loss: 2.5365 train_time: 8.5m tok/s: 2710429 +2000/20000 train_loss: 2.4844 train_time: 9.9m tok/s: 2642524 +2013/20000 val_loss: 2.4352 val_bpb: 1.1127 +stopping_early: wallclock_cap train_time: 599821ms step: 2013/20000 +peak memory allocated: 46676 MiB reserved: 52266 MiB +ema:applying EMA weights +diagnostic pre-quantization post-ema val_loss:2.43314506 val_bpb:1.11178015 eval_time:9911ms +Serialized model: 135418111 bytes +Code size (uncompressed): 213292 bytes +Code size (compressed): 41999 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 67 Hessians in 4.0s +Quantized weights: + gate_int8_row: blocks.attn.attn_gate_w + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int7)+awqgrpint8+lqer_asym: tok_emb.weight + passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, parallel_post_lambdas, parallel_resid_lambdas, skip_gates, skip_weights, smear_gate.weight, smear_lambda, softcap_neg, softcap_pos +Serialize: per-group lrzip compression... +Serialize: per-group compression done in 100.8s +Serialized model quantized+pergroup: 15955594 bytes +Total submission size quantized+pergroup: 15997593 bytes +Deserialize: per-group lrzip decompression... +Deserialize: decompression done in 17.3s +eval_compile:disabled_by_env +diagnostic quantized val_loss:2.44582432 val_bpb:1.11757370 eval_time:11393ms +beginning PPM sliding eval +ppm_mixer val_bpb:0.97373767 eval_time:451054ms order=5 H=0.999 L=0.18 T=0.8 N_tokens=47851520 N_sidecar_bytes=151074499 +ppm_sliding val_loss:2.45502055 val_bpb:0.97373767 eval_time:496384ms diff --git a/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/submission.json b/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/submission.json new file mode 100644 index 0000000000..ea7d527316 --- /dev/null +++ b/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/submission.json @@ -0,0 +1,53 @@ +{ + "author": "NewyorkDev, Claude, and Codex", + "github_id": "NewyorkDev", + "name": "BIJEPAX-lite JEPA + SP8192 CaseOps PPM", + "blurb": "A Claude-designed, JEPA-inspired training-only hidden-state regularizer layered onto the SP8192 CaseOps + per-group compression + PPM sliding evaluator. Predictor heads are separate from the base model, optimized only during training, and absent from the serialized artifact. Three full seeds complete under the 600s train/eval budgets and strict 16,000,000 byte artifact cap. Final 3-seed mean ppm_sliding val_bpb is 0.97271454.", + "date": "2026-05-01T03:39:42Z", + "val_bpb": 0.97271454, + "val_bpb_std": 0.00089703, + "val_bpb_by_seed": { + "42": 0.97234287, + "314": 0.97206308, + "999": 0.97373767 + }, + "bytes_total_max": 15999539, + "bytes_total_by_seed": { + "42": 15997180, + "314": 15999539, + "999": 15997593 + }, + "eval_time_ms_max": 502131, + "eval_time_ms_by_seed": { + "42": 502131, + "314": 499038, + "999": 496384 + }, + "train_time_ms_by_seed": { + "42": 599843, + "314": 599586, + "999": 599821 + }, + "config": { + "DISABLE_COMPILE": "1", + "CASEOPS_ENABLED": "1", + "VOCAB_SIZE": "8192", + "TTT_ENABLED": "0", + "PPM_MIXER_ENABLED": "1", + "PPM_ORDER": "5", + "PPM_H": "0.999", + "PPM_L": "0.18", + "PPM_T": "0.80", + "LQER_TOP_K": "1", + "BIJEPAX_ENABLED": "1", + "BIJEPAX_WEIGHT": "0.01", + "BIJEPAX_START_FRAC": "0.35", + "BIJEPAX_END_FRAC": "0.80", + "BIJEPAX_FWD_HOPS": "4", + "BIJEPAX_BWD_HOPS": "4", + "BIJEPAX_CYCLE": "0", + "BIJEPAX_HEAD_DIM": "32", + "BIJEPAX_STRIDE": "64", + "BIJEPAX_LR": "0.001" + } +} diff --git a/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/train_gpt.py b/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/train_gpt.py new file mode 100644 index 0000000000..b566e2b508 --- /dev/null +++ b/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/train_gpt.py @@ -0,0 +1,5002 @@ +import base64, collections, copy, fcntl, glob, io, lzma, math, os +from pathlib import Path +import random, re, subprocess, sys, time, uuid, numpy as np, sentencepiece as spm, torch, torch.distributed as dist, torch.nn.functional as F +from torch import Tensor, nn +from flash_attn_interface import ( + flash_attn_func as flash_attn_3_func, + flash_attn_varlen_func, +) +from concurrent.futures import ThreadPoolExecutor +import triton +import triton.language as tl +from triton.tools.tensor_descriptor import TensorDescriptor + + +# ===== Fused softcapped cross-entropy (Triton) — training-only path ===== +# Replaces the eager +# logits_softcap = softcap * tanh(logits / softcap) +# F.cross_entropy(logits_softcap.float(), targets, reduction="mean") +# sequence with a single fused kernel that reads logits_proj once, applies +# softcap in-register, and computes (LSE, loss) in one streaming pass. The +# backward kernel mirrors the forward so there's no stored softcapped logits. +# Numerically identical to the eager path up to fp32 accumulation differences. +_FUSED_CE_LIBRARY = "pgsubmission1draft7fusedce" +_FUSED_CE_BLOCK_SIZE = 1024 +_FUSED_CE_NUM_WARPS = 4 + + +@triton.jit +def _softcapped_ce_fwd_kernel( + logits_ptr, losses_ptr, lse_ptr, targets_ptr, + stride_logits_n, stride_logits_v, + n_rows, n_cols, softcap, + block_size: tl.constexpr, +): + row_idx = tl.program_id(0).to(tl.int64) + logits_row_ptr = logits_ptr + row_idx * stride_logits_n + max_val = -float("inf") + sum_exp = 0.0 + A = 2.0 * softcap + inv_C = 2.0 / softcap + for off in range(0, n_cols, block_size): + cols = off + tl.arange(0, block_size) + mask = cols < n_cols + val = tl.load( + logits_row_ptr + cols * stride_logits_v, + mask=mask, other=-float("inf"), + ).to(tl.float32) + z = A * tl.sigmoid(val * inv_C) + z = tl.where(mask, z, -float("inf")) + curr_max = tl.max(z, axis=0) + new_max = tl.maximum(max_val, curr_max) + sum_exp = sum_exp * tl.exp(max_val - new_max) + tl.sum(tl.exp(z - new_max), axis=0) + max_val = new_max + lse = max_val + tl.log(sum_exp) + tl.store(lse_ptr + row_idx, lse) + target = tl.load(targets_ptr + row_idx).to(tl.int32) + target_val = tl.load(logits_row_ptr + target * stride_logits_v).to(tl.float32) + target_z = A * tl.sigmoid(target_val * inv_C) + tl.store(losses_ptr + row_idx, lse - target_z) + + +@triton.jit +def _softcapped_ce_bwd_kernel( + grad_logits_ptr, grad_losses_ptr, lse_ptr, logits_ptr, targets_ptr, + stride_logits_n, stride_logits_v, + stride_grad_n, stride_grad_v, + n_rows, n_cols, softcap, + block_size: tl.constexpr, +): + row_idx = tl.program_id(0).to(tl.int64) + logits_row_ptr = logits_ptr + row_idx * stride_logits_n + grad_row_ptr = grad_logits_ptr + row_idx * stride_grad_n + lse = tl.load(lse_ptr + row_idx) + grad_loss = tl.load(grad_losses_ptr + row_idx).to(tl.float32) + target = tl.load(targets_ptr + row_idx).to(tl.int32) + A = 2.0 * softcap + inv_C = 2.0 / softcap + dz_dx_scale = A * inv_C + for off in range(0, n_cols, block_size): + cols = off + tl.arange(0, block_size) + mask = cols < n_cols + val = tl.load( + logits_row_ptr + cols * stride_logits_v, + mask=mask, other=0.0, + ).to(tl.float32) + sigmoid_u = tl.sigmoid(val * inv_C) + z = A * sigmoid_u + probs = tl.exp(z - lse) + grad_z = grad_loss * (probs - tl.where(cols == target, 1.0, 0.0)) + grad_x = grad_z * (dz_dx_scale * sigmoid_u * (1.0 - sigmoid_u)) + tl.store(grad_row_ptr + cols * stride_grad_v, grad_x, mask=mask) + + +def _validate_softcapped_ce_inputs( + logits: Tensor, targets: Tensor, softcap: float, +) -> tuple[Tensor, Tensor]: + if logits.ndim != 2: + raise ValueError(f"Expected logits.ndim=2, got {logits.ndim}") + if targets.ndim != 1: + raise ValueError(f"Expected targets.ndim=1, got {targets.ndim}") + if logits.shape[0] != targets.shape[0]: + raise ValueError( + f"Expected matching rows, got logits={tuple(logits.shape)} targets={tuple(targets.shape)}" + ) + if not logits.is_cuda or not targets.is_cuda: + raise ValueError("softcapped_cross_entropy requires CUDA tensors") + if softcap <= 0.0: + raise ValueError(f"softcap must be positive, got {softcap}") + if logits.dtype not in (torch.float16, torch.bfloat16, torch.float32): + raise ValueError(f"Unsupported logits dtype: {logits.dtype}") + logits = logits.contiguous() + targets = targets.contiguous() + if targets.dtype != torch.int64: + targets = targets.to(dtype=torch.int64) + return logits, targets + + +@torch.library.custom_op(f"{_FUSED_CE_LIBRARY}::softcapped_ce", mutates_args=()) +def softcapped_ce_op(logits: Tensor, targets: Tensor, softcap: float) -> tuple[Tensor, Tensor]: + logits, targets = _validate_softcapped_ce_inputs(logits, targets, float(softcap)) + n_rows, n_cols = logits.shape + losses = torch.empty((n_rows,), device=logits.device, dtype=torch.float32) + lse = torch.empty((n_rows,), device=logits.device, dtype=torch.float32) + _softcapped_ce_fwd_kernel[(n_rows,)]( + logits, losses, lse, targets, + logits.stride(0), logits.stride(1), + n_rows, n_cols, float(softcap), + block_size=_FUSED_CE_BLOCK_SIZE, num_warps=_FUSED_CE_NUM_WARPS, + ) + return losses, lse + + +@softcapped_ce_op.register_fake +def _(logits: Tensor, targets: Tensor, softcap: float): + if logits.ndim != 2 or targets.ndim != 1: + raise ValueError("softcapped_ce fake impl expects 2D logits and 1D targets") + if logits.shape[0] != targets.shape[0]: + raise ValueError( + f"Expected matching rows, got logits={tuple(logits.shape)} targets={tuple(targets.shape)}" + ) + n_rows = logits.shape[0] + return ( + logits.new_empty((n_rows,), dtype=torch.float32), + logits.new_empty((n_rows,), dtype=torch.float32), + ) + + +@torch.library.custom_op(f"{_FUSED_CE_LIBRARY}::softcapped_ce_backward", mutates_args=()) +def softcapped_ce_backward_op( + logits: Tensor, targets: Tensor, lse: Tensor, grad_losses: Tensor, softcap: float, +) -> Tensor: + logits, targets = _validate_softcapped_ce_inputs(logits, targets, float(softcap)) + lse = lse.contiguous() + grad_losses = grad_losses.contiguous().to(dtype=torch.float32) + if lse.ndim != 1 or grad_losses.ndim != 1: + raise ValueError("Expected 1D lse and grad_losses") + if lse.shape[0] != logits.shape[0] or grad_losses.shape[0] != logits.shape[0]: + raise ValueError( + f"Expected row-aligned lse/grad_losses, got logits={tuple(logits.shape)} " + f"lse={tuple(lse.shape)} grad_losses={tuple(grad_losses.shape)}" + ) + grad_logits = torch.empty_like(logits) + n_rows, n_cols = logits.shape + _softcapped_ce_bwd_kernel[(n_rows,)]( + grad_logits, grad_losses, lse, logits, targets, + logits.stride(0), logits.stride(1), + grad_logits.stride(0), grad_logits.stride(1), + n_rows, n_cols, float(softcap), + block_size=_FUSED_CE_BLOCK_SIZE, num_warps=_FUSED_CE_NUM_WARPS, + ) + return grad_logits + + +@softcapped_ce_backward_op.register_fake +def _(logits: Tensor, targets: Tensor, lse: Tensor, grad_losses: Tensor, softcap: float): + if logits.ndim != 2 or targets.ndim != 1 or lse.ndim != 1 or grad_losses.ndim != 1: + raise ValueError("softcapped_ce_backward fake impl expects 2D logits and 1D row tensors") + if ( + logits.shape[0] != targets.shape[0] + or logits.shape[0] != lse.shape[0] + or logits.shape[0] != grad_losses.shape[0] + ): + raise ValueError("softcapped_ce_backward fake impl expects row-aligned tensors") + return logits.new_empty(logits.shape) + + +def _softcapped_ce_setup_context( + ctx: torch.autograd.function.FunctionCtx, inputs, output, +) -> None: + logits, targets, softcap = inputs + _losses, lse = output + ctx.save_for_backward(logits, targets, lse) + ctx.softcap = float(softcap) + + +def _softcapped_ce_backward( + ctx: torch.autograd.function.FunctionCtx, grad_losses: Tensor, grad_lse: "Tensor | None", +): + del grad_lse + logits, targets, lse = ctx.saved_tensors + grad_logits = torch.ops.pgsubmission1draft7fusedce.softcapped_ce_backward( + logits, targets, lse, grad_losses, ctx.softcap + ) + return grad_logits, None, None + + +softcapped_ce_op.register_autograd( + _softcapped_ce_backward, setup_context=_softcapped_ce_setup_context, +) + + +def softcapped_cross_entropy( + logits: Tensor, targets: Tensor, softcap: float, reduction: str = "mean", +) -> Tensor: + losses, _lse = torch.ops.pgsubmission1draft7fusedce.softcapped_ce( + logits, targets, float(softcap) + ) + if reduction == "none": + return losses + if reduction == "sum": + return losses.sum() + if reduction == "mean": + return losses.mean() + raise ValueError(f"Unsupported reduction={reduction!r}") + + +class _BiJEPAXHead(nn.Module): + def __init__(self, d_model: int, head_dim: int): + super().__init__() + self.w1 = nn.Linear(d_model, head_dim, bias=False) + self.ln = nn.LayerNorm(head_dim) + self.w2 = nn.Linear(head_dim, d_model, bias=False) + nn.init.orthogonal_(self.w1.weight) + nn.init.zeros_(self.w2.weight) + + def forward(self, x: Tensor) -> Tensor: + return self.w2(F.gelu(self.ln(self.w1(x.float())))) + + +class MultiDirectionalBiJEPAX(nn.Module): + """Training-only multi-hop bidirectional/cycle JEPA regularizer.""" + + def __init__( + self, + d_model: int, + fwd_hops: tuple[int, ...] = (1, 4, 16), + bwd_hops: tuple[int, ...] = (1, 4), + cycle: bool = True, + head_dim: int = 64, + ): + super().__init__() + self.fwd_hops = list(fwd_hops) + self.bwd_hops = list(bwd_hops) + self.cycle = cycle + self.fwd_heads = nn.ModuleList( + [_BiJEPAXHead(d_model, head_dim) for _ in fwd_hops] + ) + self.bwd_heads = nn.ModuleList( + [_BiJEPAXHead(d_model, head_dim) for _ in bwd_hops] + ) + self.cyc_heads = nn.ModuleList( + [_BiJEPAXHead(d_model, head_dim) for _ in fwd_hops] + ) if cycle else nn.ModuleList() + + @staticmethod + def _cos_loss(pred: Tensor, tgt: Tensor) -> Tensor: + pred = F.normalize(pred, dim=-1) + tgt = F.normalize(tgt.float().detach(), dim=-1) + return (1.0 - (pred * tgt).sum(dim=-1)).mean() + + def forward(self, hidden: Tensor, weight: float, stride: int) -> Tensor: + if weight <= 0.0 or hidden.size(1) < 2: + return hidden.new_zeros(()) + stride = max(1, int(stride)) + seq_len = hidden.size(1) + losses: list[Tensor] = [] + for i, hop in enumerate(self.fwd_hops): + if seq_len <= hop: + continue + src = hidden[:, : seq_len - hop : stride, :] + tgt = hidden[:, hop : hop + src.size(1) * stride : stride, :] + n = min(src.size(1), tgt.size(1)) + if n <= 0: + continue + src = src[:, :n, :] + tgt = tgt[:, :n, :] + losses.append(self._cos_loss(self.fwd_heads[i](src), tgt)) + if self.cycle and i < len(self.cyc_heads): + losses.append(self._cos_loss(self.cyc_heads[i](tgt.detach()), src)) + for i, hop in enumerate(self.bwd_hops): + if seq_len <= hop: + continue + src = hidden[:, hop::stride, :] + tgt = hidden[:, : src.size(1) * stride : stride, :] + n = min(src.size(1), tgt.size(1)) + if n <= 0: + continue + losses.append(self._cos_loss(self.bwd_heads[i](src[:, :n, :]), tgt[:, :n, :])) + if not losses: + return hidden.new_zeros(()) + return weight * (sum(losses) / len(losses)) + + +def bijepax_weight_at(frac: float, start: float, end: float, peak: float) -> float: + if frac < start or frac >= end: + return 0.0 + span = max(end - start, 1e-9) + ramp = min(span * 0.15, 0.08) + if frac < start + ramp: + return peak * (frac - start) / max(ramp, 1e-9) + if frac >= end - ramp: + return peak * (end - frac) / max(ramp, 1e-9) + return peak + + +class Hyperparameters: + data_dir = os.environ.get("DATA_DIR", "./data/") + seed = int(os.environ.get("SEED", 1337)) + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_frac = float(os.environ.get("WARMDOWN_FRAC", 0.85)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786432)) + # Fused softcapped CE (Triton). Training-only — forward_logits eval path still uses + # eager softcap+F.cross_entropy. Default ON since validated as at-worst neutral. + fused_ce_enabled = bool(int(os.environ.get("FUSED_CE_ENABLED", "1"))) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 6e2)) + val_batch_tokens = int(os.environ.get("VAL_BATCH_TOKENS", 524288)) + # v13 is the sidecar-aware PPM lane. These defaults match the under-cap + # H100 package runs instead of the older TTT-first v12 defaults. + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 8192)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 11)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 4.0)) + skip_gates_enabled = bool(int(os.environ.get("SKIP_GATES_ENABLED", "1"))) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 3e1)) + rope_base = float(os.environ.get("ROPE_BASE", 1e4)) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + rope_train_seq_len = int(os.environ.get("ROPE_TRAIN_SEQ_LEN", 2048)) + rope_yarn = bool(int(os.environ.get("ROPE_YARN", "0"))) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 5.25)) + num_loops = int(os.environ.get("NUM_LOOPS", 2)) + loop_start = int(os.environ.get("LOOP_START", 3)) + loop_end = int(os.environ.get("LOOP_END", 5)) + enable_looping_at = float(os.environ.get("ENABLE_LOOPING_AT", 0.35)) + parallel_start_layer = int(os.environ.get("PARALLEL_START_LAYER", 8)) + parallel_final_lane = os.environ.get("PARALLEL_FINAL_LANE", "mean") + min_lr = float(os.environ.get("MIN_LR", 0.1)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.026)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.97)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float( + os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92) + ) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + muon_row_normalize = bool(int(os.environ.get("MUON_ROW_NORMALIZE", "1"))) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.99)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-08)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 512)) + adam_wd = float(os.environ.get("ADAM_WD", 0.02)) + muon_wd = float(os.environ.get("MUON_WD", 0.095)) + embed_wd = float(os.environ.get("EMBED_WD", 0.085)) + ema_decay = float(os.environ.get("EMA_DECAY", 0.9965)) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 80)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.0001)) + ttt_local_lr_mult = float(os.environ.get("TTT_LOCAL_LR_MULT", 0.75)) + ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 48)) + ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 2048)) + ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64)) + ttt_grad_steps = int(os.environ.get("TTT_GRAD_STEPS", 1)) + ttt_train_window_tokens = int(os.environ.get("TTT_TRAIN_WINDOW_TOKENS", 0)) + # V19: PR #1886 (renqianluo) + sunnypatneedi research log 2026-04-28 found that + # the Triton fused-CE kernel's fp32-accumulation interacts with warm-start LoRA-A + # to destabilize seeds 314/1337 at TTT_WEIGHT_DECAY=1.0. Raising the default to + # 2.0 prevents seed collapse without measurably moving stable seeds. + ttt_weight_decay = float(os.environ.get("TTT_WEIGHT_DECAY", 0.5)) + ttt_beta1 = float(os.environ.get("TTT_BETA1", 0)) + ttt_beta2 = float(os.environ.get("TTT_BETA2", 0.99)) + ttt_mask = os.environ.get("TTT_MASK", "no_qv").strip().lower() + _ttt_q_default = "1" + _ttt_v_default = "1" + if ttt_mask in ("", "all", "baseline_all"): + pass + elif ttt_mask == "no_q": + _ttt_q_default = "0" + elif ttt_mask == "no_v": + _ttt_v_default = "0" + elif ttt_mask == "no_qv": + _ttt_q_default = "0" + _ttt_v_default = "0" + else: + raise ValueError(f"Unsupported TTT_MASK={ttt_mask!r}") + ttt_q_lora = bool(int(os.environ.get("TTT_Q_LORA", _ttt_q_default))) + ttt_k_lora = bool(int(os.environ.get("TTT_K_LORA", "1"))) + ttt_v_lora = bool(int(os.environ.get("TTT_V_LORA", _ttt_v_default))) + ttt_mlp_lora = bool(int(os.environ.get("TTT_MLP_LORA", "1"))) + ttt_o_lora = bool(int(os.environ.get("TTT_O_LORA", "1"))) + ttt_optimizer = os.environ.get("TTT_OPTIMIZER", "adam") + ttt_eval_batches = os.environ.get("TTT_EVAL_BATCHES", "") + val_doc_fraction = float(os.environ.get("VAL_DOC_FRACTION", 1.0)) + compressor = os.environ.get("COMPRESSOR", "pergroup") + gptq_calibration_batches = int(os.environ.get("GPTQ_CALIBRATION_BATCHES", 16)) + gptq_reserve_seconds = float(os.environ.get("GPTQ_RESERVE_SECONDS", 0.5)) + phased_ttt_prefix_docs = int(os.environ.get("PHASED_TTT_PREFIX_DOCS", 2500)) + phased_ttt_num_phases = int(os.environ.get("PHASED_TTT_NUM_PHASES", 3)) + global_ttt_lr = float(os.environ.get("GLOBAL_TTT_LR", 0.001)) + global_ttt_momentum = float(os.environ.get("GLOBAL_TTT_MOMENTUM", 0.9)) + global_ttt_epochs = int(os.environ.get("GLOBAL_TTT_EPOCHS", 1)) + global_ttt_chunk_tokens = int(os.environ.get("GLOBAL_TTT_CHUNK_TOKENS", 32768)) + global_ttt_batch_seqs = int(os.environ.get("GLOBAL_TTT_BATCH_SEQS", 32)) + global_ttt_warmup_start_lr = float(os.environ.get("GLOBAL_TTT_WARMUP_START_LR", 0.0)) + global_ttt_warmup_chunks = int(os.environ.get("GLOBAL_TTT_WARMUP_CHUNKS", 0)) + global_ttt_grad_clip = float(os.environ.get("GLOBAL_TTT_GRAD_CLIP", 1.0)) + global_ttt_respect_doc_boundaries = bool(int(os.environ.get("GLOBAL_TTT_RESPECT_DOC_BOUNDARIES", "1"))) + matrix_bits = int(os.environ.get("MATRIX_BITS", 6)) + embed_bits = int(os.environ.get("EMBED_BITS", 7)) + matrix_clip_sigmas = float(os.environ.get("MATRIX_CLIP_SIGMAS", 12.85)) + embed_clip_sigmas = float(os.environ.get("EMBED_CLIP_SIGMAS", 14.0)) + mlp_clip_sigmas = float(os.environ.get("MLP_CLIP_SIGMAS", 11.5)) + attn_clip_sigmas = float(os.environ.get("ATTN_CLIP_SIGMAS", 13.0)) + # AttnOutGate (per-head multiplicative output gate, PR #1667 MarioPaerle). + # Zero-init weight: 2*sigmoid(0)=1 -> transparent at start. Source defaults to + # block input x ('proj'); 'q' uses raw Q projection output. + attn_out_gate_enabled = bool(int(os.environ.get("ATTN_OUT_GATE_ENABLED", "0"))) + attn_out_gate_src = os.environ.get("ATTN_OUT_GATE_SRC", "proj") + # SmearGate (input-dependent forward-1 token smear, modded-nanogpt @classiclarryd + # via PR #1667). x_t <- x_t + lam * sigmoid(W*x_t[:gate_window]) * x_{t-1}. + # lam=0 + W=0 -> transparent at init. + smear_gate_enabled = bool(int(os.environ.get("SMEAR_GATE_ENABLED", "1"))) + # Window: first GATE_WINDOW dims of the source feed the gate projection. + gate_window = int(os.environ.get("GATE_WINDOW", 12)) + # Gated Attention (Qwen, NeurIPS 2025 Best Paper, arXiv:2505.06708; + # qiuzh20/gated_attention). Per-head sigmoid gate on SDPA output, BEFORE + # out_proj. Gate input = full block input x (paper's headwise G1 variant + # driven from hidden_states). W_g shape (num_heads, dim), plain sigmoid. + # Near-zero init gives g~0.5 at step 0 (half attention output); per-block + # attn_scale (init 1.0) compensates during training. Name contains + # "attn_gate" so CONTROL_TENSOR_NAME_PATTERNS routes it to scalar AdamW. + gated_attn_enabled = bool(int(os.environ.get("GATED_ATTN_ENABLED", "0"))) + gated_attn_init_std = float(os.environ.get("GATED_ATTN_INIT_STD", 0.01)) + # Dedicated int8-per-row quantization for `attn_gate_w` tensors. These are + # small ((num_heads, dim) = (8, 512) = 4096 params) and bypass GPTQ via the + # numel<=65536 passthrough branch -> stored as fp16 (8 KB/layer, ~65 KB total + # compressed). int8-per-row cuts the raw tensor in half with negligible BPB + # impact: scales per head (8 values), symmetric quant over [-127, 127]. + # No Hessian needed (gate weights not in collect_hessians()). + gated_attn_quant_gate = bool(int(os.environ.get("GATED_ATTN_QUANT_GATE", "1"))) + # Sparse Attention Gate (modded-nanogpt-style). Keeps dense SDPA and only + # swaps the output-gate input to the first GATE_WINDOW residual dims. + # W_g: (num_heads, gate_window) = (8, 12) = 96 params/layer (~44K total), + # vs dense GatedAttn's (8, 512) = 4K/layer (~44K diff). Name "attn_gate_w" + # is shared so quant routing and int8 gate passthrough Just Work. Gate + # passthrough int8 still applies via GATED_ATTN_QUANT_GATE=1. + # Mutually exclusive with ATTN_OUT_GATE_ENABLED and GATED_ATTN_ENABLED. + sparse_attn_gate_enabled = bool(int(os.environ.get("SPARSE_ATTN_GATE_ENABLED", "1"))) + sparse_attn_gate_init_std = float(os.environ.get("SPARSE_ATTN_GATE_INIT_STD", 0.0)) + sparse_attn_gate_scale = float(os.environ.get("SPARSE_ATTN_GATE_SCALE", 0.5)) + # LQER asymmetric rank-k correction on top-K quant-error tensors (PR #1530 v2 port). + # Computes SVD of E = W_fp - W_quant, packs top-r A,B as INT2/INT4 (asym) or INTk (sym). + lqer_enabled = bool(int(os.environ.get("LQER_ENABLED", "1"))) + lqer_rank = int(os.environ.get("LQER_RANK", 4)) + lqer_top_k = int(os.environ.get("LQER_TOP_K", 3)) + lqer_factor_bits = int(os.environ.get("LQER_FACTOR_BITS", 4)) + lqer_asym_enabled = bool(int(os.environ.get("LQER_ASYM_ENABLED", "1"))) + lqer_asym_group = int(os.environ.get("LQER_ASYM_GROUP", "64")) + lqer_scope = os.environ.get("LQER_SCOPE", "all") + lqer_gain_select = bool(int(os.environ.get("LQER_GAIN_SELECT", "0"))) + awq_lite_enabled = bool(int(os.environ.get("AWQ_LITE_ENABLED", "1"))) + awq_lite_bits = int(os.environ.get("AWQ_LITE_BITS", "8")) + awq_lite_group_top_k = int(os.environ.get("AWQ_LITE_GROUP_TOP_K", "1")) + awq_lite_group_size = int(os.environ.get("AWQ_LITE_GROUP_SIZE", "64")) + # PR #1145/#1967 online n-gram tilt. This is a causal scoring overlay: + # prefix-only token/within-word/word experts propose one hint token, then + # the per-token NLL is adjusted with closed-form softmax renormalization. + ngram_tilt_enabled = bool(int(os.environ.get("NGRAM_TILT_ENABLED", "1"))) + token_order = int(os.environ.get("TOKEN_ORDER", "16")) + token_threshold = float(os.environ.get("TOKEN_THRESHOLD", "0.800")) + token_boost = float(os.environ.get("TOKEN_BOOST", "2.625")) + within_tau = float(os.environ.get("WITHIN_TAU", "0.450")) + within_boost = float(os.environ.get("WITHIN_BOOST", "0.750")) + word_order = int(os.environ.get("WORD_ORDER", "4")) + word_normalize = os.environ.get("WORD_NORMALIZE", "strip_punct_lower") + word_tau = float(os.environ.get("WORD_TAU", "0.650")) + word_boost = float(os.environ.get("WORD_BOOST", "0.750")) + agree_add_boost = float(os.environ.get("AGREE_ADD_BOOST", "0.500")) + ngram_hint_precompute_outside = bool(int(os.environ.get("NGRAM_HINT_PRECOMPUTE_OUTSIDE", "1"))) + ppm_mixer_enabled = bool(int(os.environ.get("PPM_MIXER_ENABLED", "1"))) + ppm_order = int(os.environ.get("PPM_ORDER", "5")) + ppm_h = float(os.environ.get("PPM_H", "0.995")) + ppm_l = float(os.environ.get("PPM_L", "0.20")) + ppm_t = float(os.environ.get("PPM_T", "0.80")) + ppm_dump_inputs = bool(int(os.environ.get("PPM_DUMP_INPUTS", "0"))) + # JEPA-Lite training regularizer from PR #2027. The predictor is never + # serialized; it only nudges hidden states during training. + jepa_lite_enabled = bool(int(os.environ.get("JEPA_LITE_ENABLED", "0"))) + jepa_lite_weight = float(os.environ.get("JEPA_LITE_WEIGHT", "0.03")) + jepa_lite_hidden = int(os.environ.get("JEPA_LITE_HIDDEN", "128")) + jepa_lite_horizon = int(os.environ.get("JEPA_LITE_HORIZON", "4")) + jepa_lite_stride = int(os.environ.get("JEPA_LITE_STRIDE", "16")) + jepa_lite_frequency = int(os.environ.get("JEPA_LITE_FREQUENCY", "8")) + jepa_lite_start_frac = float(os.environ.get("JEPA_LITE_START_FRAC", "0.35")) + jepa_lite_end_frac = float(os.environ.get("JEPA_LITE_END_FRAC", "0.85")) + jepa_lite_lr = float(os.environ.get("JEPA_LITE_LR", "0.01")) + bijepa_enabled = bool(int(os.environ.get("BIJEPA_ENABLED", "0"))) + bijepa_backward_weight = float(os.environ.get("BIJEPA_BACKWARD_WEIGHT", "1.0")) + bijepax_enabled = bool(int(os.environ.get("BIJEPAX_ENABLED", "0"))) + bijepax_weight = float(os.environ.get("BIJEPAX_WEIGHT", "0.04")) + bijepax_start_frac = float(os.environ.get("BIJEPAX_START_FRAC", "0.05")) + bijepax_end_frac = float(os.environ.get("BIJEPAX_END_FRAC", "0.80")) + bijepax_fwd_hops = tuple( + int(x) for x in os.environ.get("BIJEPAX_FWD_HOPS", "1,4,16").split(",") if x + ) + bijepax_bwd_hops = tuple( + int(x) for x in os.environ.get("BIJEPAX_BWD_HOPS", "1,4").split(",") if x + ) + bijepax_cycle = bool(int(os.environ.get("BIJEPAX_CYCLE", "1"))) + bijepax_head_dim = int(os.environ.get("BIJEPAX_HEAD_DIM", "64")) + bijepax_stride = int(os.environ.get("BIJEPAX_STRIDE", "8")) + bijepax_lr = float(os.environ.get("BIJEPAX_LR", "3e-4")) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + is_main_process = rank == 0 + grad_accum_steps = 8 // world_size + # CaseOps integration: optional override of dataset root + tokenizer path. + # When CASEOPS_ENABLED=1, the wrapper loads a per-token byte sidecar + # (fineweb_val_bytes_*.bin, identical shard layout to val_*.bin) and uses + # it as the canonical raw-byte budget for BPB accounting. The sidecar + # REPLACES the build_sentencepiece_luts byte-counting path entirely. + caseops_enabled = bool(int(os.environ.get("CASEOPS_ENABLED", "1"))) + _default_caseops_data = os.path.join( + data_dir, + "datasets", + "fineweb10B_sp8192_caseops", + "datasets", + "datasets", + "fineweb10B_sp8192_lossless_caps_caseops_v1_reserved", + ) + _default_caseops_tok = os.path.join( + data_dir, + "datasets", + "fineweb10B_sp8192_caseops", + "datasets", + "tokenizers", + "fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model", + ) + if caseops_enabled: + datasets_dir = os.environ.get("DATA_PATH", _default_caseops_data) + tokenizer_path = os.environ.get("TOKENIZER_PATH", _default_caseops_tok) + else: + datasets_dir = os.environ.get( + "DATA_PATH", + os.path.join(data_dir, "datasets", f"fineweb10B_sp{vocab_size}"), + ) + tokenizer_path = os.environ.get( + "TOKENIZER_PATH", + os.path.join(data_dir, "tokenizers", f"fineweb_{vocab_size}_bpe.model"), + ) + train_files = os.path.join(datasets_dir, "fineweb_train_*.bin") + val_files = os.path.join(datasets_dir, "fineweb_val_*.bin") + val_bytes_files = os.path.join(datasets_dir, "fineweb_val_bytes_*.bin") + artifact_dir = os.environ.get("ARTIFACT_DIR", "") + logfile = ( + os.path.join(artifact_dir, f"{run_id}.txt") + if artifact_dir + else f"logs/{run_id}.txt" + ) + model_path = ( + os.path.join(artifact_dir, "final_model.pt") + if artifact_dir + else "final_model.pt" + ) + quantized_model_path = ( + os.path.join(artifact_dir, "final_model.int6.ptz") + if artifact_dir + else "final_model.int6.ptz" + ) + + +_logger_hparams = None + + +def set_logging_hparams(h): + global _logger_hparams + _logger_hparams = h + + +def log(msg, console=True): + if _logger_hparams is None: + print(msg) + return + if _logger_hparams.is_main_process: + if console: + print(msg) + if _logger_hparams.logfile is not None: + with open(_logger_hparams.logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + +class ValidationData: + def __init__(self, h, device): + self.sp = spm.SentencePieceProcessor(model_file=h.tokenizer_path) + if int(self.sp.vocab_size()) != h.vocab_size: + raise ValueError( + f"VOCAB_SIZE={h.vocab_size} does not match tokenizer vocab_size={int(self.sp.vocab_size())}" + ) + self.val_tokens = load_validation_tokens(h.val_files, h.eval_seq_len) + self.caseops_enabled = bool(getattr(h, "caseops_enabled", False)) + if self.caseops_enabled: + self.base_bytes_lut = None + self.has_leading_space_lut = None + self.is_boundary_token_lut = None + else: + ( + self.base_bytes_lut, + self.has_leading_space_lut, + self.is_boundary_token_lut, + ) = build_sentencepiece_luts(self.sp, h.vocab_size, device) + self.val_bytes = None + if self.caseops_enabled: + self.val_bytes = load_validation_byte_sidecar( + h.val_bytes_files, h.eval_seq_len, self.val_tokens.numel() + ) + + +def build_sentencepiece_luts(sp, vocab_size, device): + sp_vocab_size = int(sp.vocab_size()) + assert ( + sp.piece_to_id("▁") != sp.unk_id() + ), "Tokenizer must have '▁' (space) as its own token for correct BPB byte counting" + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern, seq_len): + # Filter out CaseOps byte sidecar shards which share the val_*.bin glob. + files = [ + Path(p) + for p in sorted(glob.glob(pattern)) + if "_bytes_" not in Path(p).name + ] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = (tokens.numel() - 1) // seq_len * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def load_validation_byte_sidecar(pattern, seq_len, expected_len): + """Load CaseOps per-token byte sidecar(s). Same shard layout as token shards + (256 int32 header + uint16 array). Each entry = canonical raw-text byte + budget for that token in the corresponding val shard. Returns a CPU + int16 tensor sliced to match expected_len (i.e. val_tokens length).""" + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No byte sidecar files for pattern: {pattern}") + shards = [load_data_shard(file) for file in files] + # load_data_shard returns uint16 — that's exactly what the sidecar stores. + bytes_full = torch.cat(shards).contiguous() + if bytes_full.numel() < expected_len: + raise ValueError( + f"Byte sidecar too short: {bytes_full.numel()} < val_tokens {expected_len}" + ) + return bytes_full[:expected_len].to(torch.int32) + + +def load_data_shard(file): + header_bytes = 256 * np.dtype(" 0: + pos = start + while pos < end: + seg_starts.append(pos) + pos += max_doc_len + else: + seg_starts.append(start) + boundaries = seg_starts + [total_len] + padded_len = get_next_multiple_of_n(len(boundaries), bucket_size) + cu = torch.full((padded_len,), total_len, dtype=torch.int32, device=device) + cu[: len(boundaries)] = torch.tensor(boundaries, dtype=torch.int32, device=device) + seg_ends = seg_starts[1:] + [total_len] + max_seqlen = max(end - start for start, end in zip(seg_starts, seg_ends)) + return cu, max_seqlen + +class DocumentPackingLoader: + _shard_pool = ThreadPoolExecutor(1) + + def __init__(self, h, device, cu_bucket_size=64): + self.rank = h.rank + self.world_size = h.world_size + self.device = device + self.cu_bucket_size = cu_bucket_size + self.max_seq_len = h.train_seq_len + all_files = [Path(p) for p in sorted(glob.glob(h.train_files))] + if not all_files: + raise FileNotFoundError(f"No files found for pattern: {h.train_files}") + self.files = all_files + self.file_iter = iter(self.files) + self._init_shard(load_data_shard(next(self.file_iter))) + self._next_shard = self._submit_next_shard() + self._batch_pool = ThreadPoolExecutor(1) + self._prefetch_queue = [] + + def _init_shard(self, tokens): + global BOS_ID + self.tokens = tokens + self.shard_size = tokens.numel() + if BOS_ID is None: + BOS_ID = 1 + self.bos_idx = ( + (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + ) + self.cursor = int(self.bos_idx[0]) + + def _submit_next_shard(self): + try: + path = next(self.file_iter) + return self._shard_pool.submit(load_data_shard, path) + except StopIteration: + return None + + def _advance_shard(self): + if self._next_shard is None: + self.file_iter = iter(self.files) + self._next_shard = self._shard_pool.submit( + load_data_shard, next(self.file_iter) + ) + self._init_shard(self._next_shard.result()) + self._next_shard = self._submit_next_shard() + + def _local_doc_starts(self, local_start, total_len): + lo = np.searchsorted(self.bos_idx, local_start, side="left") + hi = np.searchsorted(self.bos_idx, local_start + total_len, side="left") + return (self.bos_idx[lo:hi] - local_start).tolist() + + def _prepare_batch(self, num_tokens_local, max_seq_len): + per_rank_span = num_tokens_local + 1 + global_span = per_rank_span * self.world_size + while self.cursor + global_span > self.shard_size: + self._advance_shard() + local_start = self.cursor + self.rank * per_rank_span + buf = self.tokens[local_start : local_start + per_rank_span] + inputs = torch.empty(per_rank_span - 1, dtype=torch.int64, pin_memory=True) + targets = torch.empty(per_rank_span - 1, dtype=torch.int64, pin_memory=True) + inputs.copy_(buf[:-1]) + targets.copy_(buf[1:]) + starts = self._local_doc_starts(local_start, inputs.numel()) + cu_seqlens, max_seqlen = _build_cu_seqlens( + starts, inputs.numel(), inputs.device, max_seq_len, self.cu_bucket_size + ) + cu_seqlens = cu_seqlens.pin_memory() + self.cursor += global_span + return inputs, targets, cu_seqlens, max_seqlen + + def next_batch(self, global_tokens, grad_accum_steps): + num_tokens_local = global_tokens // (self.world_size * grad_accum_steps) + while len(self._prefetch_queue) < 2: + self._prefetch_queue.append( + self._batch_pool.submit(self._prepare_batch, num_tokens_local, self.max_seq_len)) + inputs, targets, cu_seqlens, max_seqlen = self._prefetch_queue.pop(0).result() + self._prefetch_queue.append( + self._batch_pool.submit(self._prepare_batch, num_tokens_local, self.max_seq_len)) + return ( + inputs[None].to(self.device, non_blocking=True), + targets[None].to(self.device, non_blocking=True), + cu_seqlens.to(self.device, non_blocking=True), + max_seqlen, + ) + + +class ShuffledSequenceLoader: + def __init__(self, h, device): + self.world_size = h.world_size + self.seq_len = h.train_seq_len + self.device = device + all_files = [Path(p) for p in sorted(glob.glob(h.train_files))] + if not all_files: + raise FileNotFoundError(f"No files found for pattern: {h.train_files}") + self.files = all_files[h.rank :: h.world_size] + self.rng = np.random.Generator(np.random.PCG64(h.rank)) + self.num_tokens = [_read_num_tokens(f) for f in self.files] + self.start_inds = [[] for _ in self.files] + for si in range(len(self.files)): + self._reset_shard(si) + + def _reset_shard(self, si): + max_phase = min( + self.seq_len - 1, max(0, self.num_tokens[si] - self.seq_len - 1) + ) + phase = int(self.rng.integers(max_phase + 1)) if max_phase > 0 else 0 + num_sequences = (self.num_tokens[si] - 1 - phase) // self.seq_len + sequence_order = self.rng.permutation(num_sequences) + self.start_inds[si] = (phase + sequence_order * self.seq_len).tolist() + + def next_batch(self, global_tokens, grad_accum_steps): + device_tokens = global_tokens // (self.world_size * grad_accum_steps) + device_batch_size = device_tokens // self.seq_len + remaining = np.array([len(s) for s in self.start_inds], dtype=np.float64) + x = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64) + y = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64) + for bi in range(device_batch_size): + total = remaining.sum() + if total <= 0: + for si in range(len(self.files)): + self._reset_shard(si) + remaining = np.array( + [len(s) for s in self.start_inds], dtype=np.float64 + ) + total = remaining.sum() + probs = remaining / total + si = int(self.rng.choice(len(self.files), p=probs)) + start_ind = self.start_inds[si].pop() + remaining[si] -= 1 + mm = _get_shard_memmap(self.files[si]) + window = torch.as_tensor( + np.array(mm[start_ind : start_ind + self.seq_len + 1], dtype=np.int64) + ) + x[bi] = window[:-1] + y[bi] = window[1:] + return x.to(self.device, non_blocking=True), y.to( + self.device, non_blocking=True + ) + + +class RMSNorm(nn.Module): + def __init__(self, eps=None): + super().__init__() + self.eps = eps + + def forward(self, x): + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x): + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +@triton.jit +def fused_log_softmax_dual_gather_kernel( + logits_ptr, + target_ids_ptr, + hint_ids_ptr, + log_p_y_out_ptr, + log_q_h_out_ptr, + n_rows, + n_cols, + block_cols: tl.constexpr, +): + row_idx = tl.program_id(0) + if row_idx >= n_rows: + return + target = tl.load(target_ids_ptr + row_idx) + hint = tl.load(hint_ids_ptr + row_idx) + row_offset = row_idx * n_cols + target_logit = tl.load(logits_ptr + row_offset + target).to(tl.float32) + hint_logit = tl.load(logits_ptr + row_offset + hint).to(tl.float32) + max_val = -float("inf") + for col_start in tl.range(0, n_cols, block_cols): + cols = col_start + tl.arange(0, block_cols) + mask = cols < n_cols + vals = tl.load( + logits_ptr + row_offset + cols, mask=mask, other=-float("inf") + ).to(tl.float32) + max_val = tl.maximum(max_val, tl.max(vals, axis=0)) + sum_exp = tl.zeros((), dtype=tl.float32) + for col_start in tl.range(0, n_cols, block_cols): + cols = col_start + tl.arange(0, block_cols) + mask = cols < n_cols + vals = tl.load( + logits_ptr + row_offset + cols, mask=mask, other=0.0 + ).to(tl.float32) + sum_exp += tl.sum(tl.where(mask, tl.exp(vals - max_val), 0.0), axis=0) + lse = max_val + tl.log(sum_exp) + tl.store(log_p_y_out_ptr + row_idx, target_logit - lse) + tl.store(log_q_h_out_ptr + row_idx, hint_logit - lse) + + +def fused_log_softmax_dual_gather(logits, target_ids, hint_ids): + bsz, seqlen, vocab = logits.shape + n_rows = bsz * seqlen + logits_flat = logits.reshape(n_rows, vocab).contiguous() + target_flat = target_ids.reshape(n_rows).contiguous() + hint_flat = hint_ids.reshape(n_rows).contiguous() + log_p_y_out = torch.empty(n_rows, dtype=torch.float32, device=logits.device) + log_q_h_out = torch.empty(n_rows, dtype=torch.float32, device=logits.device) + fused_log_softmax_dual_gather_kernel[(n_rows,)]( + logits_flat, + target_flat, + hint_flat, + log_p_y_out, + log_q_h_out, + n_rows, + vocab, + block_cols=1024, + num_warps=8, + ) + return log_p_y_out.reshape(bsz, seqlen), log_q_h_out.reshape(bsz, seqlen) + + +@triton.jit +def linear_leaky_relu_square_kernel( + a_desc, + b_desc, + c_desc, + aux_desc, + M, + N, + K, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + NUM_SMS: tl.constexpr, + FORWARD: tl.constexpr, +): + dtype = tl.bfloat16 + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + tile_id_c = start_pid - NUM_SMS + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for ki in range(k_tiles): + offs_k = ki * BLOCK_SIZE_K + a = a_desc.load([offs_am, offs_k]) + b = b_desc.load([offs_bn, offs_k]) + accumulator = tl.dot(a, b.T, accumulator) + tile_id_c += NUM_SMS + offs_am_c = offs_am + offs_bn_c = offs_bn + acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) + acc = tl.permute(acc, (0, 2, 1)) + acc0, acc1 = tl.split(acc) + c0 = acc0.to(dtype) + c1 = acc1.to(dtype) + if not FORWARD: + pre0 = aux_desc.load([offs_am_c, offs_bn_c]) + pre1 = aux_desc.load([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2]) + c0 = c0 * tl.where(pre0 > 0, 2.0 * pre0, 0.3 * pre0) + c1 = c1 * tl.where(pre1 > 0, 2.0 * pre1, 0.3 * pre1) + c_desc.store([offs_am_c, offs_bn_c], c0) + c_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], c1) + if FORWARD: + aux0 = tl.where(c0 > 0, c0, 0.3 * c0) + aux1 = tl.where(c1 > 0, c1, 0.3 * c1) + aux_desc.store([offs_am_c, offs_bn_c], aux0 * aux0) + aux_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], aux1 * aux1) + + +def linear_leaky_relu_square(a, b, aux=None): + M, K = a.shape + N, K2 = b.shape + assert K == K2 + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + forward = aux is None + if aux is None: + aux = torch.empty((M, N), device=a.device, dtype=a.dtype) + num_sms = torch.cuda.get_device_properties(a.device).multi_processor_count + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 256, 128, 64 + num_stages = 4 if forward else 3 + a_desc = TensorDescriptor.from_tensor(a, [BLOCK_SIZE_M, BLOCK_SIZE_K]) + b_desc = TensorDescriptor.from_tensor(b, [BLOCK_SIZE_N, BLOCK_SIZE_K]) + c_desc = TensorDescriptor.from_tensor(c, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + aux_desc = TensorDescriptor.from_tensor(aux, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + grid = lambda _meta: ( + min(num_sms, triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N)), + ) + linear_leaky_relu_square_kernel[grid]( + a_desc, + b_desc, + c_desc, + aux_desc, + M, + N, + K, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + NUM_SMS=num_sms, + FORWARD=forward, + num_stages=num_stages, + num_warps=8, + ) + if forward: + return c, aux + return c + + +class FusedLinearLeakyReLUSquareFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, w1, w2): + x_flat = x.reshape(-1, x.shape[-1]) + pre, post = linear_leaky_relu_square(x_flat, w1) + out = F.linear(post, w2) + ctx.save_for_backward(x, w1, w2, pre, post) + return out.view(*x.shape[:-1], out.shape[-1]) + + @staticmethod + def backward(ctx, grad_output): + x, w1, w2, pre, post = ctx.saved_tensors + x_flat = x.reshape(-1, x.shape[-1]) + grad_output_flat = grad_output.reshape(-1, grad_output.shape[-1]) + dw2 = grad_output_flat.T @ post + dpre = linear_leaky_relu_square(grad_output_flat, w2.T.contiguous(), aux=pre) + dw1 = dpre.T @ x_flat + dx = dpre @ w1 + return dx.view_as(x), dw1, dw2 + + +FusedLeakyReLUSquareMLP = FusedLinearLeakyReLUSquareFunction.apply + + +class Rotary(nn.Module): + def __init__(self, dim, base=1e4, train_seq_len=1024, rope_dims=0, yarn=True): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.yarn = yarn + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / base ** ( + torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + + def forward(self, seq_len, device, dtype): + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached < seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if self.yarn and seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * scale ** (rd / (rd - 2)) + inv_freq = 1.0 / new_base ** ( + torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd + ) + else: + inv_freq = self.inv_freq.float().to(device) + t = torch.arange(seq_len, device=device, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached[:, :seq_len].to(dtype=dtype), self._sin_cached[:, :seq_len].to(dtype=dtype) + + +def apply_rotary_emb(x, cos, sin, rope_dims=0): + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * -sin + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * -sin + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len, yarn=True, + attn_out_gate=False, attn_out_gate_src="proj", gate_window=12, + gated_attn=False, gated_attn_init_std=0.01, + sparse_attn_gate=False, sparse_attn_gate_init_std=0.0, sparse_attn_gate_scale=1.0, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + if int(attn_out_gate) + int(gated_attn) + int(sparse_attn_gate) > 1: + raise ValueError( + "attn_out_gate, gated_attn, and sparse_attn_gate are mutually exclusive" + ) + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + self.q_gain = nn.Parameter( + torch.full((num_heads,), qk_gain_init, dtype=torch.float32) + ) + self.rope_dims = 0 + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=train_seq_len, yarn=yarn) + self.use_xsa = False + # AttnOutGate (PR #1667 MarioPaerle): per-head multiplicative gate on attention + # output. CastedLinear so restore_fp32_params casts back to fp32 for GPTQ. + # _zero_init -> 2*sigmoid(0)=1 -> transparent at init. + self.attn_out_gate = attn_out_gate + self.attn_out_gate_src = attn_out_gate_src + self.gate_window = gate_window + if attn_out_gate: + self.attn_gate_proj = CastedLinear(gate_window, num_heads, bias=False) + self.attn_gate_proj._zero_init = True + # Gated Attention (arXiv:2505.06708, Qwen, NeurIPS 2025). Per-head sigmoid + # gate on SDPA output, BEFORE out_proj. Gate projection W_g: (num_heads, dim). + # Name "attn_gate_w" contains "attn_gate" substring so it matches + # CONTROL_TENSOR_NAME_PATTERNS and routes to the scalar AdamW group. + # fp32 Parameter -> restore_fp32_params path covers it via the ndim<2 OR + # name-pattern check (name matches "attn_gate"). Cast to x.dtype on use. + self.gated_attn = gated_attn + if gated_attn: + W = torch.empty(num_heads, dim, dtype=torch.float32) + nn.init.normal_(W, mean=0.0, std=gated_attn_init_std) + self.attn_gate_w = nn.Parameter(W) + # Sparse attention head-output gate (modded-nanogpt style). Keeps dense SDPA + # and only narrows the gate input to the first gate_window residual dims. + # W_g: (num_heads, gate_window). y_{t,h} <- sigmoid(scale * W_g_h @ x_t[:gate_window]) * y_{t,h}. + # Shares attn_gate_w name with dense GatedAttn so the quant routing + # (CONTROL_TENSOR_NAME_PATTERNS / attn_gate_w int8 passthrough) is unchanged. + self.sparse_attn_gate = sparse_attn_gate + self.sparse_attn_gate_scale = sparse_attn_gate_scale + if sparse_attn_gate: + W = torch.empty(num_heads, gate_window, dtype=torch.float32) + if sparse_attn_gate_init_std > 0: + nn.init.normal_(W, mean=0.0, std=sparse_attn_gate_init_std) + else: + nn.init.zeros_(W) + self.attn_gate_w = nn.Parameter(W) + + def _xsa_efficient(self, y, v): + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x, q_w, k_w, v_w, out_w, cu_seqlens=None, max_seqlen=0): + bsz, seqlen, dim = x.shape + # q_raw kept around as a tap point for attn_out_gate_src='q' (post-projection, + # pre-reshape, pre-RoPE). + q_raw = F.linear(x, q_w.to(x.dtype)) + q = q_raw.reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = F.linear(x, k_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = F.linear(x, v_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + if cu_seqlens is not None: + y = flash_attn_varlen_func( + q[0], + k[0], + v[0], + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + causal=True, + window_size=(-1, -1), + )[None] + else: + y = flash_attn_3_func(q, k, v, causal=True) + if self.use_xsa: + y = self._xsa_efficient(y, v) + # AttnOutGate inlined (PR #1667). Inline + .contiguous() barrier so torch.compile + # fullgraph=True is happy (this avoids the @torch.compiler.disable trap that + # crashed gates v3). Per-head gate on (B,T,H,D) tensor: g shape [B,T,H], broadcast + # over D via [..., None]. zero-init weight -> 2*sigmoid(0)=1 -> transparent. + if self.attn_out_gate: + gate_src = q_raw if self.attn_out_gate_src == "q" else x + gate_in = gate_src[..., : self.gate_window].contiguous() + g = 2.0 * torch.sigmoid(self.attn_gate_proj(gate_in)) + y = y * g[..., None] + # Gated Attention (arXiv:2505.06708 G1). Inline + .contiguous() barrier so + # torch.compile fullgraph=True is happy. Per-head gate on (B,T,H,D): g shape + # [B,T,H], broadcast over D via [..., None]. Paper: g = sigmoid(x @ W_g.T) + # where W_g: (H, dim). .to(x.dtype) on fp32 param before broadcast with bf16. + if self.gated_attn: + x_c = x.contiguous() + g = torch.sigmoid(F.linear(x_c, self.attn_gate_w.to(x.dtype))) + y = y * g[..., None] + # Sparse head-output gate: narrower (gate_window) input, same shape g as GatedAttn. + if self.sparse_attn_gate: + gate_in = x[..., : self.gate_window].contiguous() + g = torch.sigmoid( + self.sparse_attn_gate_scale + * F.linear(gate_in, self.attn_gate_w.to(x.dtype)) + ) + y = y * g[..., None] + y = y.reshape(bsz, seqlen, dim) + self._last_proj_input = y.detach() if getattr(self, "_calib", False) else None + return F.linear(y, out_w.to(x.dtype)) + + +class MLP(nn.Module): + def __init__(self, dim, mlp_mult): + super().__init__() + self.use_fused = True + + def forward(self, x, up_w, down_w): + if self.training and self.use_fused: + return FusedLeakyReLUSquareMLP(x, up_w.to(x.dtype), down_w.to(x.dtype)) + hidden = F.leaky_relu(F.linear(x, up_w.to(x.dtype)), negative_slope=0.3).square() + self._last_down_input = hidden.detach() if getattr(self, "_calib", False) else None + return F.linear(hidden, down_w.to(x.dtype)) + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + train_seq_len, + layer_idx=0, + ln_scale=False, + yarn=True, + attn_out_gate=False, + attn_out_gate_src="proj", + gate_window=12, + gated_attn=False, + gated_attn_init_std=0.01, + sparse_attn_gate=False, + sparse_attn_gate_init_std=0.0, + sparse_attn_gate_scale=1.0, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention( + dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len, yarn=yarn, + attn_out_gate=attn_out_gate, attn_out_gate_src=attn_out_gate_src, gate_window=gate_window, + gated_attn=gated_attn, gated_attn_init_std=gated_attn_init_std, + sparse_attn_gate=sparse_attn_gate, + sparse_attn_gate_init_std=sparse_attn_gate_init_std, + sparse_attn_gate_scale=sparse_attn_gate_scale, + ) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter( + torch.stack((torch.ones(dim), torch.zeros(dim))).float() + ) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + + def forward(self, x, x0, q_w, k_w, v_w, out_w, up_w, down_w, cu_seqlens=None, max_seqlen=0): + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn( + self.attn_norm(x_in) * self.ln_scale_factor, + q_w, k_w, v_w, out_w, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[ + None, None, : + ] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor, up_w, down_w) + return x_out + +class GPT(nn.Module): + def __init__(self, h): + super().__init__() + if h.logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {h.logit_softcap}") + self.tie_embeddings = h.tie_embeddings + self.tied_embed_init_std = h.tied_embed_init_std + self.logit_softcap = h.logit_softcap + self.fused_ce_enabled = bool(h.fused_ce_enabled) + self.tok_emb = nn.Embedding(h.vocab_size, h.model_dim) + self.num_layers = h.num_layers + head_dim = h.model_dim // h.num_heads + kv_dim = h.num_kv_heads * head_dim + hidden_dim = int(h.mlp_mult * h.model_dim) + self.qo_bank = nn.Parameter(torch.empty(2 * h.num_layers, h.model_dim, h.model_dim)) + self.kv_bank = nn.Parameter(torch.empty(2 * h.num_layers, kv_dim, h.model_dim)) + self.mlp_up_bank = nn.Parameter(torch.empty(h.num_layers, hidden_dim, h.model_dim)) + self.mlp_down_bank = nn.Parameter(torch.empty(h.num_layers, h.model_dim, hidden_dim)) + self.num_encoder_layers = h.num_layers // 2 + self.num_decoder_layers = h.num_layers - self.num_encoder_layers + self.blocks = nn.ModuleList( + [ + Block( + h.model_dim, + h.num_heads, + h.num_kv_heads, + h.mlp_mult, + h.rope_base, + h.qk_gain_init, + h.train_seq_len, + layer_idx=i, + ln_scale=h.ln_scale, + yarn=h.rope_yarn, + attn_out_gate=h.attn_out_gate_enabled, + attn_out_gate_src=h.attn_out_gate_src, + gate_window=h.gate_window, + gated_attn=h.gated_attn_enabled, + gated_attn_init_std=h.gated_attn_init_std, + sparse_attn_gate=h.sparse_attn_gate_enabled, + sparse_attn_gate_init_std=h.sparse_attn_gate_init_std, + sparse_attn_gate_scale=h.sparse_attn_gate_scale, + ) + for i in range(h.num_layers) + ] + ) + if h.rope_dims > 0: + head_dim = h.model_dim // h.num_heads + for block in self.blocks: + block.attn.rope_dims = h.rope_dims + block.attn.rotary = Rotary( + head_dim, + base=h.rope_base, + train_seq_len=h.train_seq_len, + rope_dims=h.rope_dims, + yarn=h.rope_yarn, + ) + self.final_norm = RMSNorm() + self.lm_head = ( + None + if h.tie_embeddings + else CastedLinear(h.model_dim, h.vocab_size, bias=False) + ) + if self.lm_head is not None: + self.lm_head._zero_init = True + if h.xsa_last_n > 0: + for i in range(max(0, h.num_layers - h.xsa_last_n), h.num_layers): + self.blocks[i].attn.use_xsa = True + self.looping_active = False + if h.num_loops > 0: + loop_seg = list(range(h.loop_start, h.loop_end + 1)) + all_indices = list(range(h.loop_start)) + for _ in range(h.num_loops + 1): + all_indices.extend(loop_seg) + all_indices.extend(range(h.loop_end + 1, h.num_layers)) + num_enc = len(all_indices) // 2 + self.encoder_indices = all_indices[:num_enc] + self.decoder_indices = all_indices[num_enc:] + else: + self.encoder_indices = list(range(self.num_encoder_layers)) + self.decoder_indices = list(range(self.num_encoder_layers, h.num_layers)) + self.num_skip_weights = min( + len(self.encoder_indices), len(self.decoder_indices) + ) + self.skip_weights = nn.Parameter( + torch.ones(self.num_skip_weights, h.model_dim, dtype=torch.float32) + ) + self.skip_gates = ( + nn.Parameter( + torch.zeros(self.num_skip_weights, h.model_dim, dtype=torch.float32) + ) + if h.skip_gates_enabled + else None + ) + self.parallel_start_layer = h.parallel_start_layer + self.parallel_final_lane = h.parallel_final_lane.lower() + self.parallel_post_lambdas = nn.Parameter( + torch.ones(h.num_layers, 2, 2, dtype=torch.float32) + ) + self.parallel_resid_lambdas = nn.Parameter( + torch.full((h.num_layers, 2), 1.1, dtype=torch.float32) + ) + # SmearGate (PR #1667 / modded-nanogpt @classiclarryd): + # x_t <- x_t + lam * sigmoid(W * x_t[:gate_window]) * x_{t-1}. + # Per-token forward-1 smear of the embedding lane. W zero-init + lam=0 -> + # transparent at init. Uses CastedLinear so restore_fp32_params handles dtype. + self.smear_gate_enabled = h.smear_gate_enabled + if self.smear_gate_enabled: + self.smear_window = h.gate_window + self.smear_gate = CastedLinear(self.smear_window, 1, bias=False) + self.smear_gate._zero_init = True + self.smear_lambda = nn.Parameter(torch.zeros(1, dtype=torch.float32)) + # V19: Asymmetric Logit Rescale (PR #1923 jorge-asenjo). + # Two learnable softcap scales applied on the EVAL path (forward_logits + + # forward_ttt). Init to logit_softcap so the layer is identity at step 0. + # Train path keeps the single fused softcap to preserve PR #1855 numerics. + self.asym_logit_enabled = bool(int(os.environ.get("ASYM_LOGIT_RESCALE", "1"))) + if self.asym_logit_enabled: + self.softcap_pos = nn.Parameter(torch.tensor(float(h.logit_softcap), dtype=torch.float32)) + self.softcap_neg = nn.Parameter(torch.tensor(float(h.logit_softcap), dtype=torch.float32)) + self.jepa_horizon = h.jepa_lite_horizon + self.jepa_stride = h.jepa_lite_stride + self.register_buffer( + "jepa_weight_current", + torch.zeros((), dtype=torch.float32), + persistent=False, + ) + self.jepa_predictor = ( + nn.Sequential( + CastedLinear(h.model_dim, h.jepa_lite_hidden, bias=False), + nn.LeakyReLU(0.5), + CastedLinear(h.jepa_lite_hidden, h.model_dim, bias=False), + ) + if h.jepa_lite_enabled + else None + ) + self.bijepa_enabled = bool(h.bijepa_enabled) + self.bijepa_backward_weight = float(h.bijepa_backward_weight) + self.jepa_bwd_predictor = ( + nn.Sequential( + CastedLinear(h.model_dim, h.jepa_lite_hidden, bias=False), + nn.LeakyReLU(0.5), + CastedLinear(h.jepa_lite_hidden, h.model_dim, bias=False), + ) + if h.jepa_lite_enabled and h.bijepa_enabled + else None + ) + self._init_weights() + + def _init_weights(self): + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + n = self.num_layers + proj_scale = 1.0 / math.sqrt(2 * n) + for i in range(n): + nn.init.orthogonal_(self.qo_bank.data[i], gain=1.0) + nn.init.zeros_(self.qo_bank.data[n + i]) + self.qo_bank.data[n + i].mul_(proj_scale) + nn.init.orthogonal_(self.kv_bank.data[i], gain=1.0) + nn.init.orthogonal_(self.kv_bank.data[n + i], gain=1.0) + for i in range(n): + nn.init.orthogonal_(self.mlp_up_bank.data[i], gain=1.0) + nn.init.zeros_(self.mlp_down_bank.data[i]) + self.mlp_down_bank.data[i].mul_(proj_scale) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif ( + module.weight.ndim == 2 + and module.weight.shape[0] >= 64 + and module.weight.shape[1] >= 64 + ): + nn.init.orthogonal_(module.weight, gain=1.0) + + def _bank_weights(self, i): + n = self.num_layers + return ( + self.qo_bank[i], + self.kv_bank[i], + self.kv_bank[n + i], + self.qo_bank[n + i], + self.mlp_up_bank[i], + self.mlp_down_bank[i], + ) + + def _parallel_block( + self, block_idx, lane0, lane1, x0, + q_w, k_w, v_w, out_w, up_w, down_w, + cu_seqlens=None, max_seqlen=0, + ): + block = self.blocks[block_idx] + mix = block.resid_mix.to(dtype=lane0.dtype) + attn_read = mix[0][None, None, :] * lane0 + mix[1][None, None, :] * x0 + attn_out = block.attn( + block.attn_norm(attn_read) * block.ln_scale_factor, + q_w, k_w, v_w, out_w, + cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, + ) + attn_out = block.attn_scale.to(dtype=attn_out.dtype)[None, None, :] * attn_out + mlp_read = lane1 + mlp_out = block.mlp_scale.to(dtype=lane1.dtype)[None, None, :] * block.mlp( + block.mlp_norm(mlp_read) * block.ln_scale_factor, up_w, down_w + ) + attn_resid = self.parallel_resid_lambdas[block_idx, 0].to(dtype=lane0.dtype) + attn_post = self.parallel_post_lambdas[block_idx, 0].to(dtype=lane0.dtype) + mlp_resid = self.parallel_resid_lambdas[block_idx, 1].to(dtype=lane0.dtype) + mlp_post = self.parallel_post_lambdas[block_idx, 1].to(dtype=lane0.dtype) + lane0 = attn_resid * lane0 + attn_post[0] * attn_out + mlp_post[0] * mlp_out + lane1 = mlp_resid * lane1 + attn_post[1] * attn_out + mlp_post[1] * mlp_out + return lane0, lane1 + + def _final_parallel_hidden(self, lane0, lane1): + if self.parallel_final_lane == "mlp": + return lane1 + if self.parallel_final_lane == "attn": + return lane0 + return 0.5 * (lane0 + lane1) + + def _forward_hidden(self, input_ids, cu_seqlens=None, max_seqlen=0): + """Run the encoder/decoder stack to the final RMSNorm; returns pre-projection hidden. + Shared by eval (softcap+projection via forward_logits) and train (fused CE path).""" + x = self.tok_emb(input_ids) + # SmearGate (PR #1667). lam=0 + W=0 -> identity at init. + # Cross-doc leak fix: zero the prev-token smear at any position whose current token + # is BOS, so the BOS embedding starting doc N+1 in a packed stream is not + # contaminated by doc N's last token (audited issue on PR#1797 base). + if self.smear_gate_enabled: + sl = self.smear_lambda.to(dtype=x.dtype) + gate_in = x[:, 1:, : self.smear_window].contiguous() + g = sl * torch.sigmoid(self.smear_gate(gate_in)) + not_bos = (input_ids[:, 1:] != BOS_ID).to(x.dtype).unsqueeze(-1) + x = torch.cat([x[:, :1], x[:, 1:] + g * x[:, :-1] * not_bos], dim=1) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips = [] + enc_iter = ( + self.encoder_indices + if self.looping_active + else range(self.num_encoder_layers) + ) + dec_iter = ( + self.decoder_indices + if self.looping_active + else range( + self.num_encoder_layers, + self.num_encoder_layers + self.num_decoder_layers, + ) + ) + for i in enc_iter: + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + x = self.blocks[i](x, x0, q_w, k_w, v_w, out_w, up_w, down_w, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + skips.append(x) + psl = self.parallel_start_layer + lane0 = None + lane1 = None + for skip_idx, i in enumerate(dec_iter): + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + if i >= psl and psl > 0: + if lane0 is None: + lane0 = x + lane1 = x + if skip_idx < self.num_skip_weights and skips: + skip = skips.pop() + w = self.skip_weights[skip_idx].to(dtype=lane0.dtype)[None, None, :] + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=lane0.dtype))[None, None, :] + lane0 = torch.lerp(w * skip, lane0, g) + else: + lane0 = lane0 + w * skip + lane0, lane1 = self._parallel_block( + i, lane0, lane1, x0, q_w, k_w, v_w, out_w, up_w, down_w, + cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, + ) + else: + if skip_idx < self.num_skip_weights and skips: + scaled_skip = ( + self.skip_weights[skip_idx].to(dtype=x.dtype)[None, None, :] + * skips.pop() + ) + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=x.dtype))[None, None, :] + x = torch.lerp(scaled_skip, x, g) + else: + x = x + scaled_skip + x = self.blocks[i](x, x0, q_w, k_w, v_w, out_w, up_w, down_w, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + if lane0 is not None: + x = self._final_parallel_hidden(lane0, lane1) + x = self.final_norm(x) + return x + + def _project_logits(self, hidden): + if self.tie_embeddings: + return F.linear(hidden, self.tok_emb.weight) + return self.lm_head(hidden) + + def _apply_asym_softcap(self, logits): + # V19: Asymmetric softcap (PR #1923). Splits the logit_softcap scalar into + # learnable positive/negative branches. Score-first preserved: still a + # bounded, normalized post-projection nonlinearity feeding a standard + # softmax over the full vocab. + sp = self.softcap_pos.to(logits.dtype) + sn = self.softcap_neg.to(logits.dtype) + return torch.where(logits > 0, sp * torch.tanh(logits / sp), sn * torch.tanh(logits / sn)) + + def forward_logits(self, input_ids, cu_seqlens=None, max_seqlen=0): + hidden = self._forward_hidden(input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + logits_proj = self._project_logits(hidden) + if self.asym_logit_enabled: + return self._apply_asym_softcap(logits_proj) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + def jepa_aux_loss(self, features): + if self.jepa_predictor is None: + return features.new_zeros(()) + horizon = max(1, int(self.jepa_horizon)) + stride = max(1, int(self.jepa_stride)) + if features.size(1) <= horizon: + return features.new_zeros(()) + src = features[:, :-horizon:stride, :] + tgt = features[:, horizon::stride, :] + n = min(src.size(1), tgt.size(1)) + if n <= 0: + return features.new_zeros(()) + src = F.normalize(src[:, :n, :].float(), dim=-1) + tgt = F.normalize(tgt[:, :n, :].detach().float(), dim=-1) + pred = F.normalize(self.jepa_predictor(src), dim=-1) + loss = (1.0 - (pred * tgt).sum(dim=-1)).mean() + if self.jepa_bwd_predictor is not None: + bwd_src = F.normalize(features[:, horizon::stride, :][:, :n, :].float(), dim=-1) + bwd_tgt = F.normalize( + features[:, :-horizon:stride, :][:, :n, :].detach().float(), + dim=-1, + ) + bwd_pred = F.normalize(self.jepa_bwd_predictor(bwd_src), dim=-1) + bwd_loss = (1.0 - (bwd_pred * bwd_tgt).sum(dim=-1)).mean() + loss = loss + self.bijepa_backward_weight * bwd_loss + return loss * self.jepa_weight_current.to( + device=features.device, + dtype=features.dtype, + ) + + def forward(self, input_ids, target_ids, cu_seqlens=None, max_seqlen=0): + hidden = self._forward_hidden(input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + logits_proj = self._project_logits(hidden) + flat_targets = target_ids.reshape(-1) + # Fused softcapped-CE kernel (training path only). Applies softcap inside the + # Triton kernel; takes pre-softcap logits_proj. Non-fused path matches stock + # PR-1736 numerics exactly (softcap in fp32, then F.cross_entropy on fp32). + if self.fused_ce_enabled: + ce_loss = softcapped_cross_entropy( + logits_proj.reshape(-1, logits_proj.size(-1)), + flat_targets, + self.logit_softcap, + reduction="mean", + ) + else: + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + ce_loss = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + flat_targets, + reduction="mean", + ) + if self.jepa_predictor is not None: + ce_loss = ce_loss + self.jepa_aux_loss(hidden) + return ce_loss + + def forward_with_hidden(self, input_ids, target_ids, cu_seqlens=None, max_seqlen=0): + hidden = self._forward_hidden( + input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ) + logits_proj = self._project_logits(hidden) + flat_targets = target_ids.reshape(-1) + if self.fused_ce_enabled: + ce_loss = softcapped_cross_entropy( + logits_proj.reshape(-1, logits_proj.size(-1)), + flat_targets, + self.logit_softcap, + reduction="mean", + ) + else: + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + ce_loss = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + flat_targets, + reduction="mean", + ) + if self.jepa_predictor is not None: + ce_loss = ce_loss + self.jepa_aux_loss(hidden) + return ce_loss, hidden + + def forward_ttt(self, input_ids, target_ids, lora, hint_ids=None): + x = self.tok_emb(input_ids) + # SmearGate on the TTT path — same inline compute as forward_logits. + # Cross-doc leak fix: see _forward_hidden comment. + if self.smear_gate_enabled: + sl = self.smear_lambda.to(dtype=x.dtype) + gate_in = x[:, 1:, : self.smear_window].contiguous() + g = sl * torch.sigmoid(self.smear_gate(gate_in)) + not_bos = (input_ids[:, 1:] != BOS_ID).to(x.dtype).unsqueeze(-1) + x = torch.cat([x[:, :1], x[:, 1:] + g * x[:, :-1] * not_bos], dim=1) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips = [] + enc_iter = ( + self.encoder_indices + if self.looping_active + else list(range(self.num_encoder_layers)) + ) + dec_iter = ( + self.decoder_indices + if self.looping_active + else list( + range( + self.num_encoder_layers, + self.num_encoder_layers + self.num_decoder_layers, + ) + ) + ) + slot = 0 + for i in enc_iter: + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + x = self._block_with_lora(self.blocks[i], x, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w) + slot += 1 + skips.append(x) + psl = self.parallel_start_layer + lane0 = None + lane1 = None + for skip_idx, i in enumerate(dec_iter): + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + if i >= psl and psl > 0: + if lane0 is None: + lane0 = x + lane1 = x + if skip_idx < self.num_skip_weights and skips: + skip = skips.pop() + w = self.skip_weights[skip_idx].to(dtype=lane0.dtype)[None, None, :] + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=lane0.dtype))[None, None, :] + lane0 = torch.lerp(w * skip, lane0, g) + else: + lane0 = lane0 + w * skip + lane0, lane1 = self._parallel_block_with_lora( + i, lane0, lane1, x0, lora, slot, + q_w, k_w, v_w, out_w, up_w, down_w, + ) + else: + if skip_idx < self.num_skip_weights and skips: + scaled_skip = ( + self.skip_weights[skip_idx].to(dtype=x.dtype)[None, None, :] + * skips.pop() + ) + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=x.dtype))[None, None, :] + x = torch.lerp(scaled_skip, x, g) + else: + x = x + scaled_skip + x = self._block_with_lora(self.blocks[i], x, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w) + slot += 1 + if lane0 is not None: + x = self._final_parallel_hidden(lane0, lane1) + x = self.final_norm(x) + if self.tie_embeddings: + logits = F.linear(x, self.tok_emb.weight) + else: + logits = self.lm_head(x) + logits = logits + lora.lm_head_lora(x) + # V19: same asymmetric softcap on the TTT eval path. + if self.asym_logit_enabled: + logits = self._apply_asym_softcap(logits) + else: + logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) + bsz, sl, V = logits.shape + if hint_ids is None: + return F.cross_entropy( + logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none" + ).reshape(bsz, sl) + if not logits.requires_grad: + log_p_y, log_q_h = fused_log_softmax_dual_gather( + logits, target_ids, hint_ids.clamp(min=0) + ) + return -log_p_y, log_q_h + ls = F.log_softmax(logits.float(), dim=-1) + log_p_y = ls.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1) + log_q_h = ls.gather(-1, hint_ids.clamp(min=0).unsqueeze(-1)).squeeze(-1) + return -log_p_y, log_q_h + + def _block_with_lora(self, block, x, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w): + mix = block.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + n = block.attn_norm(x_in) * block.ln_scale_factor + attn = block.attn + bsz, seqlen, dim = n.shape + # Keep raw Q for AttnOutGate src='q' (matches forward path semantics). + q_raw = F.linear(n, q_w.to(n.dtype)) + if lora.q_loras is not None: + q_raw = q_raw + lora.q_loras[slot](n) + q = q_raw.reshape(bsz, seqlen, attn.num_heads, attn.head_dim) + k = F.linear(n, k_w.to(n.dtype)) + if lora.k_loras is not None: + k = k + lora.k_loras[slot](n) + k = k.reshape(bsz, seqlen, attn.num_kv_heads, attn.head_dim) + v = F.linear(n, v_w.to(n.dtype)) + if lora.v_loras is not None: + v = v + lora.v_loras[slot](n) + v = v.reshape(bsz, seqlen, attn.num_kv_heads, attn.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = attn.rotary(seqlen, n.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, attn.rope_dims) + k = apply_rotary_emb(k, cos, sin, attn.rope_dims) + q = q * attn.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if attn.use_xsa: + y = attn._xsa_efficient(y, v) + # AttnOutGate (TTT path) — inline + .contiguous() barrier, same as the eval path. + if attn.attn_out_gate: + gate_src = q_raw if attn.attn_out_gate_src == "q" else n + gate_in = gate_src[..., : attn.gate_window].contiguous() + g = 2.0 * torch.sigmoid(attn.attn_gate_proj(gate_in)) + y = y * g[..., None] + # Gated Attention (TTT path). Gate input is n (post-norm block input), same + # as eval path. .to(n.dtype) on fp32 param before bf16 broadcast. + if attn.gated_attn: + n_c = n.contiguous() + g = torch.sigmoid(F.linear(n_c, attn.attn_gate_w.to(n.dtype))) + y = y * g[..., None] + # Sparse attention head-output gate (TTT path) — must match the eval path in + # forward() exactly, else training (which applied the gate) and TTT eval (which + # skipped it) produce mismatched representations and catastrophic BPB regression. + if attn.sparse_attn_gate: + gate_in = n[..., : attn.gate_window].contiguous() + g = torch.sigmoid( + attn.sparse_attn_gate_scale + * F.linear(gate_in, attn.attn_gate_w.to(n.dtype)) + ) + y = y * g[..., None] + y = y.reshape(bsz, seqlen, dim) + attn_out = F.linear(y, out_w.to(n.dtype)) + if lora.o_loras is not None: + attn_out = attn_out + lora.o_loras[slot](n) + x_out = x_in + block.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + mlp_n = block.mlp_norm(x_out) * block.ln_scale_factor + mlp_out = block.mlp(mlp_n, up_w, down_w) + if lora.mlp_loras is not None: + mlp_out = mlp_out + lora.mlp_loras[slot](mlp_n) + x_out = x_out + block.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * mlp_out + return x_out + + def _parallel_block_with_lora( + self, block_idx, lane0, lane1, x0, lora, slot, + q_w, k_w, v_w, out_w, up_w, down_w, + ): + block = self.blocks[block_idx] + mix = block.resid_mix.to(dtype=lane0.dtype) + attn_read = mix[0][None, None, :] * lane0 + mix[1][None, None, :] * x0 + n = block.attn_norm(attn_read) * block.ln_scale_factor + attn = block.attn + bsz, seqlen, dim = n.shape + q_raw = F.linear(n, q_w.to(n.dtype)) + if lora.q_loras is not None: + q_raw = q_raw + lora.q_loras[slot](n) + q = q_raw.reshape(bsz, seqlen, attn.num_heads, attn.head_dim) + k = F.linear(n, k_w.to(n.dtype)) + if lora.k_loras is not None: + k = k + lora.k_loras[slot](n) + k = k.reshape(bsz, seqlen, attn.num_kv_heads, attn.head_dim) + v = F.linear(n, v_w.to(n.dtype)) + if lora.v_loras is not None: + v = v + lora.v_loras[slot](n) + v = v.reshape(bsz, seqlen, attn.num_kv_heads, attn.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = attn.rotary(seqlen, n.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, attn.rope_dims) + k = apply_rotary_emb(k, cos, sin, attn.rope_dims) + q = q * attn.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if attn.use_xsa: + y = attn._xsa_efficient(y, v) + # AttnOutGate (TTT parallel path) — inline + .contiguous() barrier. + if attn.attn_out_gate: + gate_src = q_raw if attn.attn_out_gate_src == "q" else n + gate_in = gate_src[..., : attn.gate_window].contiguous() + g = 2.0 * torch.sigmoid(attn.attn_gate_proj(gate_in)) + y = y * g[..., None] + # Gated Attention (TTT parallel path). Gate input is n (post-norm block input). + if attn.gated_attn: + n_c = n.contiguous() + g = torch.sigmoid(F.linear(n_c, attn.attn_gate_w.to(n.dtype))) + y = y * g[..., None] + # Sparse attention head-output gate (TTT parallel path) — must match the + # eval path in forward() to keep train/eval semantics in sync. + if attn.sparse_attn_gate: + gate_in = n[..., : attn.gate_window].contiguous() + g = torch.sigmoid( + attn.sparse_attn_gate_scale + * F.linear(gate_in, attn.attn_gate_w.to(n.dtype)) + ) + y = y * g[..., None] + y = y.reshape(bsz, seqlen, dim) + attn_out = F.linear(y, out_w.to(n.dtype)) + if lora.o_loras is not None: + attn_out = attn_out + lora.o_loras[slot](n) + attn_out = block.attn_scale.to(dtype=attn_out.dtype)[None, None, :] * attn_out + mlp_read = lane1 + mlp_n = block.mlp_norm(mlp_read) * block.ln_scale_factor + mlp_out = block.mlp(mlp_n, up_w, down_w) + if lora.mlp_loras is not None: + mlp_out = mlp_out + lora.mlp_loras[slot](mlp_n) + mlp_out = block.mlp_scale.to(dtype=lane1.dtype)[None, None, :] * mlp_out + attn_resid = self.parallel_resid_lambdas[block_idx, 0].to(dtype=lane0.dtype) + attn_post = self.parallel_post_lambdas[block_idx, 0].to(dtype=lane0.dtype) + mlp_resid = self.parallel_resid_lambdas[block_idx, 1].to(dtype=lane0.dtype) + mlp_post = self.parallel_post_lambdas[block_idx, 1].to(dtype=lane0.dtype) + lane0 = attn_resid * lane0 + attn_post[0] * attn_out + mlp_post[0] * mlp_out + lane1 = mlp_resid * lane1 + attn_post[1] * attn_out + mlp_post[1] * mlp_out + return lane0, lane1 + + +class BatchedLinearLoRA(nn.Module): + # PR-1767: rank-scaled output (alpha/rank), like standard LoRA. Decouples + # effective magnitude from rank so changing rank does not change LR scale. + _ALPHA = float(os.environ.get("TTT_LORA_ALPHA", "144")) + # PR-1767: optionally keep A warm across per-doc resets (only B is zeroed). + # Accumulates useful feature directions across documents within a TTT phase. + _WARM_START_A = bool(int(os.environ.get("TTT_WARM_START_A", "1"))) + + def __init__(self, bsz, in_features, out_features, rank): + super().__init__() + self._bound = 1.0 / math.sqrt(in_features) + self._scale = self._ALPHA / rank + self.A = nn.Parameter( + torch.empty(bsz, rank, in_features).uniform_(-self._bound, self._bound) + ) + self.B = nn.Parameter(torch.zeros(bsz, out_features, rank)) + + def reset(self): + with torch.no_grad(): + if not self._WARM_START_A: + self.A.uniform_(-self._bound, self._bound) + self.B.zero_() + + def forward(self, x): + return ((x @ self.A.transpose(1, 2)) @ self.B.transpose(1, 2)) * self._scale + + +class BatchedTTTLoRA(nn.Module): + def __init__( + self, bsz, model, rank, + q_lora=True, k_lora=True, v_lora=True, mlp_lora=True, o_lora=True, + ): + super().__init__() + self.bsz = bsz + dim = model.qo_bank.shape[-1] + vocab = model.tok_emb.num_embeddings + if getattr(model, "looping_active", False): + num_slots = len(model.encoder_indices) + len(model.decoder_indices) + else: + num_slots = len(model.blocks) + kv_dim = model.blocks[0].attn.num_kv_heads * ( + dim // model.blocks[0].attn.num_heads + ) + embed_dim = model.tok_emb.embedding_dim + self.lm_head_lora = BatchedLinearLoRA(bsz, embed_dim, vocab, rank) + self.q_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_slots)] + ) + if q_lora + else None + ) + self.v_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, kv_dim, rank) for _ in range(num_slots)] + ) + if v_lora + else None + ) + self.k_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, kv_dim, rank) for _ in range(num_slots)] + ) + if k_lora + else None + ) + self.mlp_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_slots)] + ) + if mlp_lora + else None + ) + self.o_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_slots)] + ) + if o_lora + else None + ) + + def reset(self): + with torch.no_grad(): + self.lm_head_lora.reset() + for loras in [self.q_loras, self.v_loras, self.k_loras, + self.mlp_loras, self.o_loras]: + if loras is not None: + for lora in loras: + lora.reset() + + +# Polar Express per-iteration minimax Newton-Schulz coefficients (PR #1344). +# Replaces the fixed (3.4445, -4.775, 2.0315) coefficients of stock Muon. +# Applied at backend_steps=5 — taking more than 5 iterations from this list +# falls back to the final (converged) tuple via the slice guard below. +_PE_COEFFS = ( + (8.156554524902461, -22.48329292557795, 15.878769915207462), + (4.042929935166739, -2.808917465908714, 0.5000178451051316), + (3.8916678022926607, -2.772484153217685, 0.5060648178503393), + (3.285753657755655, -2.3681294933425376, 0.46449024233003106), + (2.3465413258596377, -1.7097828382687081, 0.42323551169305323), +) + + +@torch.compile +def zeropower_via_newtonschulz5(G, steps=10, eps=1e-07): + was_2d = G.ndim == 2 + if was_2d: + G = G.unsqueeze(0) + X = G.bfloat16() + transposed = X.size(-2) > X.size(-1) + if transposed: + X = X.mT + X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps) + coeffs = _PE_COEFFS[:steps] if steps <= len(_PE_COEFFS) else _PE_COEFFS + for a, b, c in coeffs: + A = X @ X.mT + B = b * A + c * (A @ A) + X = a * X + B @ X + if transposed: + X = X.mT + if was_2d: + X = X.squeeze(0) + return X + + +class Muon(torch.optim.Optimizer): + def __init__( + self, + params, + lr, + momentum, + backend_steps, + nesterov=True, + weight_decay=0.0, + row_normalize=False, + ): + super().__init__( + params, + dict( + lr=lr, + momentum=momentum, + backend_steps=backend_steps, + nesterov=nesterov, + weight_decay=weight_decay, + row_normalize=row_normalize, + ), + ) + self._built = False + + def _build(self): + self._distributed = dist.is_available() and dist.is_initialized() + self._world_size = dist.get_world_size() if self._distributed else 1 + self._rank = dist.get_rank() if self._distributed else 0 + ws = self._world_size + self._bank_meta = [] + for group in self.param_groups: + for p in group["params"]: + B = p.shape[0] + padded_B = ((B + ws - 1) // ws) * ws + shard_B = padded_B // ws + tail = p.shape[1:] + dev = p.device + self._bank_meta.append({ + "p": p, + "B": B, + "padded_grad": torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + "shard": torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + "shard_mom": torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + "full_update": torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + "scale": max(1, p.shape[-2] / p.shape[-1]) ** 0.5, + }) + self._bank_meta.sort(key=lambda m: -m["p"].numel()) + self._built = True + + def launch_reduce_scatters(self): + if not self._built: + self._build() + if not self._distributed: + return + self._rs_futures = [] + for m in self._bank_meta: + p = m["p"] + if p.grad is None: + self._rs_futures.append(None) + continue + pg = m["padded_grad"] + pg[: m["B"]].copy_(p.grad) + fut = dist.reduce_scatter_tensor( + m["shard"], pg, op=dist.ReduceOp.AVG, async_op=True + ) + self._rs_futures.append(fut) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + if not self._built: + self._build() + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + row_normalize = group.get("row_normalize", False) + prev_ag_handle = None + prev_m = None + sharded = self._distributed and hasattr(self, "_rs_futures") + for idx, m in enumerate(self._bank_meta): + p = m["p"] + if p.grad is None: + continue + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m["p"] + upd = prev_m["full_update"][: prev_m["B"]] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd, alpha=-lr * prev_m["scale"]) + if sharded and self._rs_futures[idx] is not None: + self._rs_futures[idx].wait() + g = m["shard"] + buf = m["shard_mom"] + else: + g = p.grad.bfloat16() + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + update = g.add(buf, alpha=momentum) + else: + update = buf + if row_normalize: + rn = update.float().norm(dim=-1, keepdim=True).clamp_min(1e-07) + update = update / rn.to(update.dtype) + update = zeropower_via_newtonschulz5(update, steps=backend_steps) + if sharded: + prev_ag_handle = dist.all_gather_into_tensor( + m["full_update"], update, async_op=True + ) + prev_m = m + else: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + p.add_(update, alpha=-lr * m["scale"]) + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m["p"] + upd = prev_m["full_update"][: prev_m["B"]] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd, alpha=-lr * prev_m["scale"]) + if hasattr(self, "_rs_futures"): + del self._rs_futures + return loss + + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,skip_gates,parallel_post_lambdas,parallel_resid_lambdas,attn_gate_proj,attn_gate_w,smear_gate,smear_lambda", + ).split(",") + if pattern +) + + +PACKED_REPLICATED_GRAD_MAX_NUMEL = 1 << 15 + + +class Optimizers: + def __init__(self, h, base_model): + matrix_params = [ + base_model.qo_bank, + base_model.kv_bank, + base_model.mlp_up_bank, + base_model.mlp_down_bank, + ] + block_named_params = list(base_model.blocks.named_parameters()) + scalar_params = [ + p + for (name, p) in block_named_params + if p.ndim < 2 + or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if base_model.skip_gates is not None and base_model.skip_gates.numel() > 0: + scalar_params.append(base_model.skip_gates) + if base_model.parallel_post_lambdas is not None: + scalar_params.append(base_model.parallel_post_lambdas) + if base_model.parallel_resid_lambdas is not None: + scalar_params.append(base_model.parallel_resid_lambdas) + # SmearGate params live on GPT root (not in .blocks), so add them by hand. + # Both are tiny (gate_window scalars + 1 lambda). Optimized via scalar Adam. + if getattr(base_model, "smear_gate_enabled", False): + scalar_params.append(base_model.smear_gate.weight) + scalar_params.append(base_model.smear_lambda) + token_lr = h.tied_embed_lr if h.tie_embeddings else h.embed_lr + tok_params = [ + {"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr} + ] + self.optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=h.embed_wd, + fused=True, + ) + self.optimizer_muon = Muon( + matrix_params, + lr=h.matrix_lr, + momentum=h.muon_momentum, + backend_steps=h.muon_backend_steps, + weight_decay=h.muon_wd, + row_normalize=h.muon_row_normalize, + ) + for group in self.optimizer_muon.param_groups: + group["base_lr"] = h.matrix_lr + self.optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": h.scalar_lr, "base_lr": h.scalar_lr}], + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=h.adam_wd, + fused=True, + ) + jepa_params = ( + list(base_model.jepa_predictor.parameters()) + if getattr(base_model, "jepa_predictor", None) is not None + else [] + ) + if getattr(base_model, "jepa_bwd_predictor", None) is not None: + jepa_params.extend(base_model.jepa_bwd_predictor.parameters()) + if jepa_params: + self.optimizer_jepa = torch.optim.AdamW( + [ + { + "params": jepa_params, + "lr": h.jepa_lite_lr, + "base_lr": h.jepa_lite_lr, + } + ], + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=0.0, + fused=True, + ) + else: + self.optimizer_jepa = None + self.optimizers = [ + self.optimizer_tok, + self.optimizer_muon, + self.optimizer_scalar, + ] + if self.optimizer_jepa is not None: + self.optimizers.append(self.optimizer_jepa) + self.replicated_params = list(tok_params[0]["params"]) + self.replicated_params.extend(scalar_params) + self.replicated_params.extend(jepa_params) + self.replicated_large_params = [] + self.replicated_packed_params = [] + for p in self.replicated_params: + if p.numel() <= PACKED_REPLICATED_GRAD_MAX_NUMEL: + self.replicated_packed_params.append(p) + else: + self.replicated_large_params.append(p) + self._aux_stream = torch.cuda.Stream() + + def __iter__(self): + return iter(self.optimizers) + + def zero_grad_all(self): + for opt in self.optimizers: + opt.zero_grad(set_to_none=True) + + def _all_reduce_packed_grads(self): + grads_by_key = collections.defaultdict(list) + for p in self.replicated_packed_params: + if p.grad is not None: + grads_by_key[(p.grad.device, p.grad.dtype)].append(p.grad) + for grads in grads_by_key.values(): + flat = torch.empty( + sum(g.numel() for g in grads), + device=grads[0].device, + dtype=grads[0].dtype, + ) + offset = 0 + for g in grads: + n = g.numel() + flat[offset : offset + n].copy_(g.contiguous().view(-1)) + offset += n + dist.all_reduce(flat, op=dist.ReduceOp.AVG) + offset = 0 + for g in grads: + n = g.numel() + g.copy_(flat[offset : offset + n].view_as(g)) + offset += n + + def step(self, distributed=False): + self.optimizer_muon.launch_reduce_scatters() + if distributed: + reduce_handles = [ + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG, async_op=True) + for p in self.replicated_large_params + if p.grad is not None + ] + self._all_reduce_packed_grads() + for handle in reduce_handles: + handle.wait() + self._aux_stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self._aux_stream): + self.optimizer_tok.step() + self.optimizer_scalar.step() + if self.optimizer_jepa is not None: + self.optimizer_jepa.step() + self.optimizer_muon.step() + torch.cuda.current_stream().wait_stream(self._aux_stream) + self.zero_grad_all() + + +def restore_fp32_params(model): + for module in model.modules(): + if isinstance(module, CastedLinear): + module.float() + for name, param in model.named_parameters(): + if ( + param.ndim < 2 + or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ) and param.dtype != torch.float32: + param.data = param.data.float() + if hasattr(model, "qo_bank") and model.qo_bank is not None: + model.qo_bank.data = model.qo_bank.data.float() + model.kv_bank.data = model.kv_bank.data.float() + model.mlp_up_bank.data = model.mlp_up_bank.data.float() + model.mlp_down_bank.data = model.mlp_down_bank.data.float() + + +def collect_hessians(model, train_loader, h, device, n_calibration_batches=64): + hessians = {} + act_sumsq = {} + act_counts = {} + hooks = [] + for i, block in enumerate(model.blocks): + block.attn._calib = True + block.mlp._calib = True + block.mlp.use_fused = False + + def make_attn_hook(layer_idx): + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + x_sq = x.square().sum(dim=0) + x_count = x.shape[0] + for suffix in ["c_q", "c_k", "c_v"]: + name = f"blocks.{layer_idx}.attn.{suffix}.weight" + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + if name not in act_sumsq: + act_sumsq[name] = torch.zeros( + x.shape[1], dtype=torch.float32, device=device + ) + act_counts[name] = 0 + act_sumsq[name] += x_sq + act_counts[name] += x_count + y = module._last_proj_input + if y is not None: + y = y.float() + if y.ndim == 3: + y = y.reshape(-1, y.shape[-1]) + name = f"blocks.{layer_idx}.attn.proj.weight" + if name not in hessians: + hessians[name] = torch.zeros( + y.shape[1], y.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(y.T, y) + if name not in act_sumsq: + act_sumsq[name] = torch.zeros( + y.shape[1], dtype=torch.float32, device=device + ) + act_counts[name] = 0 + act_sumsq[name] += y.square().sum(dim=0) + act_counts[name] += y.shape[0] + return hook_fn + + def make_mlp_hook(layer_idx): + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + name = f"blocks.{layer_idx}.mlp.fc.weight" + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + if name not in act_sumsq: + act_sumsq[name] = torch.zeros( + x.shape[1], dtype=torch.float32, device=device + ) + act_counts[name] = 0 + act_sumsq[name] += x.square().sum(dim=0) + act_counts[name] += x.shape[0] + h_act = module._last_down_input + if h_act is not None: + h_act = h_act.float() + if h_act.ndim == 3: + h_act = h_act.reshape(-1, h_act.shape[-1]) + name = f"blocks.{layer_idx}.mlp.proj.weight" + if name not in hessians: + hessians[name] = torch.zeros( + h_act.shape[1], h_act.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(h_act.T, h_act) + if name not in act_sumsq: + act_sumsq[name] = torch.zeros( + h_act.shape[1], dtype=torch.float32, device=device + ) + act_counts[name] = 0 + act_sumsq[name] += h_act.square().sum(dim=0) + act_counts[name] += h_act.shape[0] + return hook_fn + + for i, block in enumerate(model.blocks): + hooks.append(block.attn.register_forward_hook(make_attn_hook(i))) + hooks.append(block.mlp.register_forward_hook(make_mlp_hook(i))) + + # Hessian hooks for embedding factorization projection layers + def make_linear_input_hook(weight_name): + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + if weight_name not in hessians: + hessians[weight_name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[weight_name].addmm_(x.T, x) + return hook_fn + + if model.tie_embeddings: + hook_module = model.final_norm + + def make_output_hook(name): + def hook_fn(module, inp, out): + x = out.detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + if name not in act_sumsq: + act_sumsq[name] = torch.zeros( + x.shape[1], dtype=torch.float32, device=device + ) + act_counts[name] = 0 + act_sumsq[name] += x.square().sum(dim=0) + act_counts[name] += x.shape[0] + return hook_fn + + hooks.append( + hook_module.register_forward_hook(make_output_hook("tok_emb.weight")) + ) + model.eval() + with torch.no_grad(): + for _ in range(n_calibration_batches): + x, _ = train_loader.next_batch(h.train_batch_tokens, h.grad_accum_steps) + model.forward_logits(x) + for hook in hooks: + hook.remove() + for i, block in enumerate(model.blocks): + block.attn._calib = False + block.mlp._calib = False + block.mlp.use_fused = True + for name in hessians: + hessians[name] = hessians[name].cpu() / n_calibration_batches + act_stats = {} + for name, sumsq in act_sumsq.items(): + count = max(act_counts.get(name, 0), 1) + act_stats[name] = (sumsq / count).sqrt().cpu() + return hessians, act_stats + + +def gptq_quantize_weight( + w, + H, + clip_sigmas=3.0, + clip_range=63, + block_size=128, + protect_groups=None, + group_size=None, + protect_clip_range=None, +): + W_orig = w.float().clone() + rows, cols = W_orig.shape + H = H.float().clone() + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + damp = 0.01 * H.diag().mean() + H.diagonal().add_(damp) + perm = torch.argsort(H.diag(), descending=True) + invperm = torch.argsort(perm) + W_perm = W_orig[:, perm].clone() + W_perm[:, dead[perm]] = 0 + H = H[perm][:, perm] + H_flip = torch.flip(H, dims=(0, 1)) + L_flip = torch.linalg.cholesky(H_flip) + U = torch.flip(L_flip, dims=(0, 1)) + eye = torch.eye(H.shape[0], device=H.device, dtype=H.dtype) + Hinv = torch.linalg.solve_triangular(U, eye, upper=True) + row_std = W_orig.std(dim=1) + s = (clip_sigmas * row_std / clip_range).clamp_min(1e-10).to(torch.float16) + sf = s.float() + protect_meta = None + protect_mask_perm = None + s_hi = None + sf_hi = None + if ( + protect_groups + and group_size is not None + and protect_clip_range is not None + and protect_clip_range > clip_range + ): + protect_mask = torch.zeros(cols, dtype=torch.bool) + starts = [] + for (start, end) in protect_groups: + if start < 0 or end > cols or end <= start: + continue + protect_mask[start:end] = True + starts.append(start) + if starts: + protect_mask_perm = protect_mask[perm] + s_hi = (clip_sigmas * row_std / protect_clip_range).clamp_min(1e-10).to( + torch.float16 + ) + sf_hi = s_hi.float() + protect_meta = { + "starts": torch.tensor(starts, dtype=torch.int16), + "size": int(group_size), + "s_hi": s_hi, + } + Q = torch.zeros(rows, cols, dtype=torch.int8) + W_work = W_perm.clone() + for i1 in range(0, cols, block_size): + i2 = min(i1 + block_size, cols) + W_block = W_work[:, i1:i2].clone() + Hinv_block = Hinv[i1:i2, i1:i2] + Err = torch.zeros(rows, i2 - i1) + for j in range(i2 - i1): + w_col = W_block[:, j] + d = Hinv_block[j, j] + if protect_mask_perm is not None and bool(protect_mask_perm[i1 + j]): + q_col = torch.clamp( + torch.round(w_col / sf_hi), + -protect_clip_range, + protect_clip_range, + ) + w_recon = q_col.float() * sf_hi + else: + q_col = torch.clamp(torch.round(w_col / sf), -clip_range, clip_range) + w_recon = q_col.float() * sf + Q[:, i1 + j] = q_col.to(torch.int8) + err = (w_col - w_recon) / d + Err[:, j] = err + W_block[:, j:] -= err.unsqueeze(1) * Hinv_block[j, j:].unsqueeze(0) + if i2 < cols: + W_work[:, i2:] -= Err @ Hinv[i1:i2, i2:] + return Q[:, invperm], s, protect_meta + + +def _quantize_gate_int8_row(w): + # Symmetric int8-per-row quantization for small gate tensors. w shape + # (R, C) -> (R,) scales in fp16, int8 values in [-127, 127]. Single scale + # per row keeps accuracy high while halving storage vs fp16. + W = w.float().contiguous() + row_max = W.abs().amax(dim=1).clamp_min(1e-10) + s = (row_max / 127.0).to(torch.float16) + sf = s.float().view(-1, 1) + q = torch.clamp(torch.round(W / sf), -127, 127).to(torch.int8) + return q, s + + +def _lqer_pack(A, B, bits): + rng = 2 ** (bits - 1) - 1 + sA = (A.abs().amax(dim=1).clamp_min(1e-10) / rng).to(torch.float16) + sB = (B.abs().amax(dim=1).clamp_min(1e-10) / rng).to(torch.float16) + qA = torch.clamp(torch.round(A / sA.float().view(-1, 1)), -rng, rng).to(torch.int8) + qB = torch.clamp(torch.round(B / sB.float().view(-1, 1)), -rng, rng).to(torch.int8) + return qA, sA, qB, sB + + +def _lqer_pack_asym(A, B, g=64): + # A: INT2 per-matrix scalar (signed [-2,1], scale = |A|max/1.5). + sA = (A.abs().amax().clamp_min(1e-10) / 1.5).to(torch.float16) + qA = torch.clamp(torch.round(A / sA.float()), -2, 1).to(torch.int8) + # B: INT4 groupwise g over flattened B (signed [-8,7], per-group scale). + Bf = B.reshape(-1, g) + Bmax = Bf.abs().amax(dim=-1, keepdim=True).clamp_min(1e-10) + sB = (Bmax / 7.5).to(torch.float16).reshape(-1) + qB = torch.clamp(torch.round(Bf / sB.float().reshape(-1, 1)), -8, 7).to( + torch.int8 + ).reshape(B.shape) + return qA, sA, qB, sB + + +def _lqer_fit_quantized(E, h): + U, S, Vh = torch.linalg.svd(E, full_matrices=False) + r = min(h.lqer_rank, S.numel()) + if r <= 0: + return None + A = (U[:, :r] * S[:r]).contiguous() + B = Vh[:r, :].contiguous() + asym_on = bool(getattr(h, "lqer_asym_enabled", False)) + asym_g = int(getattr(h, "lqer_asym_group", 64)) + if asym_on and B.numel() % asym_g == 0: + qA, sA, qB, sB = _lqer_pack_asym(A, B, asym_g) + A_hat = qA.float() * float(sA) + g_sz = qB.numel() // sB.numel() + B_hat = (qB.reshape(-1, g_sz).float() * sB.float().view(-1, 1)).reshape( + qB.shape + ) + return { + "kind": "asym", + "qA": qA, + "sA": sA, + "qB": qB, + "sB": sB, + "delta": A_hat @ B_hat, + } + qA, sA, qB, sB = _lqer_pack(A, B, h.lqer_factor_bits) + A_hat = qA.float() * sA.float().view(-1, 1) + B_hat = qB.float() * sB.float().view(-1, 1) + return { + "kind": "sym", + "qA": qA, + "sA": sA, + "qB": qB, + "sB": sB, + "delta": A_hat @ B_hat, + } + + +def _awq_lite_group_candidates(w, act_rms, group_size): + cols = w.shape[1] + n_groups = cols // group_size + if n_groups <= 0: + return [] + weight_score = w.float().abs().mean(dim=0) + saliency = act_rms.float() * weight_score + cands = [] + for gi in range(n_groups): + start = gi * group_size + end = start + group_size + score = float(saliency[start:end].sum()) + cands.append((score, start, end)) + return cands + + +def gptq_mixed_quantize(state_dict, hessians, act_stats, h): + result = {} + meta = {} + quant_gate = bool(getattr(h, "gated_attn_quant_gate", False)) + lqer_on = bool(getattr(h, "lqer_enabled", False)) + awq_on = bool(getattr(h, "awq_lite_enabled", False)) + lqer_cands = {} + awq_selected = collections.defaultdict(list) + if awq_on: + awq_cands = [] + for (name, tensor) in state_dict.items(): + t = tensor.detach().cpu().contiguous() + if t.is_floating_point() and t.numel() > 65536 and name in act_stats: + bits = h.embed_bits if "tok_emb" in name else h.matrix_bits + if bits < h.awq_lite_bits: + for score, start, end in _awq_lite_group_candidates( + t, act_stats[name], h.awq_lite_group_size + ): + awq_cands.append((score, name, start, end)) + awq_cands.sort(key=lambda x: -x[0]) + for (_score, name, start, end) in awq_cands[: h.awq_lite_group_top_k]: + awq_selected[name].append((start, end)) + for (name, tensor) in state_dict.items(): + t = tensor.detach().cpu().contiguous() + # Dedicated int8-per-row path for attn_gate_w (bypasses both GPTQ and + # fp16 passthrough). Applied BEFORE the numel<=65536 passthrough check + # so the gate tensor is routed here instead of to fp16. + if ( + quant_gate + and t.is_floating_point() + and t.ndim == 2 + and name.endswith(".attn_gate_w") + # Dense GatedAttn: (num_heads, dim) = (8, 512) = 4096. + # Sparse gate: (num_heads, gate_window) = (8, 12) = 96. + # Both need int8-per-row routing; the 1024 lower bound in stock + # PR-1736 presumed dense-only. Widen to catch both. + and 32 <= t.numel() <= 8192 + ): + gq, gs = _quantize_gate_int8_row(t) + result[name + ".gq"] = gq + result[name + ".gs"] = gs + meta[name] = "gate_int8_row" + continue + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough (float16)" + continue + if "tok_emb" in name: + cs = h.embed_clip_sigmas + elif ".mlp." in name: + cs = h.mlp_clip_sigmas + elif ".attn." in name: + cs = h.attn_clip_sigmas + else: + cs = h.matrix_clip_sigmas + bits = h.embed_bits if "tok_emb" in name else h.matrix_bits + clip_range = 2 ** (bits - 1) - 1 + q, s, protect_meta = gptq_quantize_weight( + t, + hessians[name], + clip_sigmas=cs, + clip_range=clip_range, + protect_groups=awq_selected.get(name), + group_size=h.awq_lite_group_size if name in awq_selected else None, + protect_clip_range=(2 ** (h.awq_lite_bits - 1) - 1) + if name in awq_selected + else None, + ) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = f"gptq (int{bits})" + W_q = q.float() * s.float().view(-1, 1) + if protect_meta is not None: + result[name + ".awqg_start"] = protect_meta["starts"] + result[name + ".awqg_s_hi"] = protect_meta["s_hi"] + result[name + ".awqg_size"] = torch.tensor( + protect_meta["size"], dtype=torch.int16 + ) + meta[name] = meta[name] + f"+awqgrpint{h.awq_lite_bits}" + gsz = protect_meta["size"] + for start in protect_meta["starts"].tolist(): + W_q[:, start : start + gsz] = ( + q[:, start : start + gsz].float() + * protect_meta["s_hi"].float().view(-1, 1) + ) + if lqer_on: + # LQER is fit on top of the fully realized GPTQ base, which already + # includes any higher-precision AWQ-protected groups. + scope = str(getattr(h, "lqer_scope", "all")).lower() + scope_ok = ( + scope == "all" + or (scope == "mlp" and ".mlp." in name) + or (scope == "attn" and ".attn." in name) + or (scope == "embed" and "tok_emb" in name) + ) + if scope_ok: + E = t.float() - W_q + err_norm = float(E.norm()) + if err_norm > 0: + lqer_cands[name] = (E, err_norm) + if lqer_on and lqer_cands: + if bool(getattr(h, "lqer_gain_select", False)): + scored = [] + for (name, (E, base_err)) in lqer_cands.items(): + fit = _lqer_fit_quantized(E, h) + if fit is None: + continue + new_err = float((E - fit["delta"]).norm()) + gain = base_err - new_err + if gain > 0: + scored.append((gain, name, fit)) + scored.sort(key=lambda x: -x[0]) + for (_gain, name, fit) in scored[: h.lqer_top_k]: + if fit["kind"] == "asym": + result[name + ".lqA_a"] = fit["qA"] + result[name + ".lqAs_a"] = fit["sA"] + result[name + ".lqB_a"] = fit["qB"] + result[name + ".lqBs_a"] = fit["sB"] + meta[name] = meta[name] + "+lqer_asym" + else: + result[name + ".lqA"] = fit["qA"] + result[name + ".lqAs"] = fit["sA"] + result[name + ".lqB"] = fit["qB"] + result[name + ".lqBs"] = fit["sB"] + meta[name] = meta[name] + "+lqer" + else: + top = sorted(lqer_cands.items(), key=lambda kv: -kv[1][1])[: h.lqer_top_k] + asym_on = bool(getattr(h, "lqer_asym_enabled", False)) + asym_g = int(getattr(h, "lqer_asym_group", 64)) + for (name, (E, _)) in top: + U, S, Vh = torch.linalg.svd(E, full_matrices=False) + r = min(h.lqer_rank, S.numel()) + A = (U[:, :r] * S[:r]).contiguous() + B = Vh[:r, :].contiguous() + if asym_on and B.numel() % asym_g == 0: + qA, sA, qB, sB = _lqer_pack_asym(A, B, asym_g) + result[name + ".lqA_a"] = qA + result[name + ".lqAs_a"] = sA + result[name + ".lqB_a"] = qB + result[name + ".lqBs_a"] = sB + meta[name] = meta[name] + "+lqer_asym" + else: + qA, sA, qB, sB = _lqer_pack(A, B, h.lqer_factor_bits) + result[name + ".lqA"] = qA + result[name + ".lqAs"] = sA + result[name + ".lqB"] = qB + result[name + ".lqBs"] = sB + meta[name] = meta[name] + "+lqer" + categories = collections.defaultdict(set) + for (name, cat) in meta.items(): + short = re.sub("\\.\\d+$", "", re.sub("blocks\\.\\d+", "blocks", name)) + categories[cat].add(short) + log("Quantized weights:") + for cat in sorted(categories): + log(f" {cat}: {', '.join(sorted(categories[cat]))}") + return result, meta + +def dequantize_mixed(result, meta, template_sd): + out = {} + for (name, orig) in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if "passthrough" in info: + t = result[name] + if t.dtype == torch.float16 and orig_dtype in ( + torch.float32, + torch.bfloat16, + ): + t = t.to(orig_dtype) + out[name] = t + continue + if info == "gate_int8_row": + gq = result[name + ".gq"] + gs = result[name + ".gs"] + out[name] = (gq.float() * gs.float().view(-1, 1)).to(orig_dtype) + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + W = q.float() * s.float().view(q.shape[0], *[1] * (q.ndim - 1)) + else: + W = q.float() * float(s.item()) + if "awqgrpint" in info: + starts = result[name + ".awqg_start"].tolist() + s_hi = result[name + ".awqg_s_hi"].float() + gsz = int(result[name + ".awqg_size"].item()) + for start in starts: + W[:, start : start + gsz] = ( + q[:, start : start + gsz].float() * s_hi.view(-1, 1) + ) + if "lqer_asym" in info: + qA_t = result[name + ".lqA_a"] + sA_t = result[name + ".lqAs_a"] + qB_t = result[name + ".lqB_a"] + sB_t = result[name + ".lqBs_a"] + qA = qA_t.float() * float(sA_t) + g_sz = qB_t.numel() // sB_t.numel() + qB = (qB_t.reshape(-1, g_sz).float() * sB_t.float().view(-1, 1)).reshape( + qB_t.shape + ) + W = W + qA @ qB + elif "lqer" in info: + qA = result[name + ".lqA"].float() * result[name + ".lqAs"].float().view(-1, 1) + qB = result[name + ".lqB"].float() * result[name + ".lqBs"].float().view(-1, 1) + W = W + qA @ qB + out[name] = W.to(orig_dtype) + return out + + +_BSHF_MAGIC = b"BSHF" + + +# ── Per-group lrzip compression (ported from PR#1586 via PR#1667/1729) ──────── + +_GROUP_ORDER = [ + "_tok_emb.weight.q", + "attn.c_k.weight.q", "attn.c_q.weight.q", + "attn.c_v.weight.q", "attn.proj.weight.q", + "mlp.fc.weight.q", "mlp.proj.weight.q", +] +_SIMSORT_KEYS = {"_tok_emb.weight.q", "attn.c_q.weight.q", "mlp.fc.weight.q"} +_PACK_MAGIC = b"PGRP" + + +def _similarity_sort_l1(matrix): + import numpy as _np + n = matrix.shape[0] + used = _np.zeros(n, dtype=bool) + order = [0] + used[0] = True + cur = matrix[0].astype(_np.float32) + for _ in range(n - 1): + dists = _np.sum(_np.abs(matrix[~used].astype(_np.float32) - cur), axis=1) + unused = _np.where(~used)[0] + best = unused[_np.argmin(dists)] + order.append(best) + used[best] = True + cur = matrix[best].astype(_np.float32) + return _np.array(order, dtype=_np.uint16) + + +def _lrzip_compress(data, tmpdir, label): + inp = os.path.join(tmpdir, f"{label}.bin") + out = f"{inp}.lrz" + with open(inp, "wb") as f: + f.write(data) + subprocess.run(["lrzip", "-z", "-L", "9", "-o", out, inp], capture_output=True, check=True) + with open(out, "rb") as f: + result = f.read() + os.remove(inp); os.remove(out) + return result + + +def _lrzip_decompress(data, tmpdir, label): + inp = os.path.join(tmpdir, f"{label}.lrz") + out = os.path.join(tmpdir, f"{label}.bin") + with open(inp, "wb") as f: + f.write(data) + subprocess.run(["lrzip", "-d", "-f", "-o", out, inp], capture_output=True, check=True) + with open(out, "rb") as f: + result = f.read() + os.remove(inp); os.remove(out) + return result + + +def _pack_streams(streams): + import struct + n = len(streams) + hdr = _PACK_MAGIC + struct.pack("", p) + if m: + return bytes([int(m.group(1), 16)]) + return (" " + p[1:]).encode() if p.startswith("▁") else p.encode() + + +def _ppm_mixture_bpb(tgt_np, lp_np, sp, O=4, H=0.9, L_=0.05, T=0.9, token_byte_lens_np=None): + V = sp.vocab_size() + piece_bytes = [None] * V + piece_lens = np.zeros(V, dtype=np.int32) + for i in range(V): + b = _ppm_piece_bytes(sp, i) + piece_bytes[i] = b + piece_lens[i] = len(b) + if token_byte_lens_np is None: + per_tok_len = piece_lens[tgt_np] + bs = b''.join(piece_bytes[int(t)] for t in tgt_np) + kept_lp = lp_np + else: + chunks = [] + kept_lp_parts = [] + lens_parts = [] + for t, lp, side_len in zip(tgt_np, lp_np, token_byte_lens_np): + side_len = int(side_len) + if side_len <= 0: + continue + b = piece_bytes[int(t)] + if not b: + continue + if len(b) > side_len: + b = b[:side_len] + elif len(b) < side_len: + b = b + b[-1:] * (side_len - len(b)) + chunks.append(b) + kept_lp_parts.append(float(lp)) + lens_parts.append(side_len) + if not chunks: + return float("inf") + bs = b''.join(chunks) + per_tok_len = np.asarray(lens_parts, dtype=np.int32) + kept_lp = np.asarray(kept_lp_parts, dtype=np.float64) + N = len(bs) + rep_lp = np.repeat(kept_lp.astype(np.float64), per_tok_len) + rep_len = np.repeat(per_tok_len.astype(np.float64), per_tok_len) + nlp = np.where(rep_len > 0, rep_lp / rep_len, 0.0) + tabs = [dict() for _ in range(O + 1)] + plp = np.empty(N, dtype=np.float64) + cf = np.empty(N, dtype=np.float64) + LN256 = math.log(1 / 256) + log_ = math.log + h_ctx = b'' + for i in range(N): + x = bs[i] + if i == 0: + plp[i] = LN256 + cf[i] = 1 / 256 + else: + esc = 1.0 + pf = 0.0 + cf_mx = 0 + cf_tot = 256 + cf_seen = False + lim = O if i > O else i + for o in range(lim, -1, -1): + k = h_ctx[-o:] if o else b'' + e = tabs[o].get(k) + if e is None: + continue + if not cf_seen: + cf_mx = e[1] + cf_tot = e[0] + cf_seen = True + tot = e[0] + d = e[2] + c = d.get(x, 0) + if c > 0: + pf = esc * (2 * c - 1) / (2 * tot) + break + esc *= len(d) / (2 * tot) + else: + pf = esc / 256 + if pf < 1e-20: + pf = 1e-20 + plp[i] = log_(pf) + cf[i] = (cf_mx / cf_tot) if cf_seen else 1 / 256 + for o in range(O + 1): + k = h_ctx[-o:] if o else b'' + e = tabs[o].get(k) + if e is None: + tabs[o][k] = [1, 1, {x: 1}] + else: + e[0] += 1 + d = e[2] + cnt = d.get(x, 0) + 1 + d[x] = cnt + if cnt > e[1]: + e[1] = cnt + h_ctx = (h_ctx + bytes([x]))[-O:] + nn_prob = np.exp(nlp) + ppm_prob = np.exp(plp) + + def _mix_bpb(Hv, Lv, Tv): + lam_v = np.where(cf > Tv, Lv, Hv) + pm_v = lam_v * nn_prob + (1 - lam_v) * ppm_prob + return float(-np.log2(np.maximum(pm_v, 1e-300)).sum() / N) + + default_bpb = _mix_bpb(H, L_, T) + if os.environ.get("PPM_SWEEP_GRID", "0") == "1": + hs = [float(x) for x in os.environ.get("PPM_SWEEP_HS", str(H)).split(",") if x.strip()] + ls = [float(x) for x in os.environ.get("PPM_SWEEP_LS", str(L_)).split(",") if x.strip()] + ts = [float(x) for x in os.environ.get("PPM_SWEEP_TS", str(T)).split(",") if x.strip()] + combo_count = len(hs) * len(ls) * len(ts) + max_combos = int(os.environ.get("PPM_SWEEP_MAX_COMBOS", "256")) + if combo_count > max_combos and os.environ.get("PPM_SWEEP_ALLOW_SLOW", "0") != "1": + log( + f"ppm_sweep skipped: combos={combo_count} max={max_combos}; " + "dump inputs and replay offline, or set PPM_SWEEP_ALLOW_SLOW=1" + ) + return default_bpb + best = (default_bpb, H, L_, T) + for Hv in hs: + for Lv in ls: + for Tv in ts: + bpb = _mix_bpb(Hv, Lv, Tv) + if bpb < best[0]: + best = (bpb, Hv, Lv, Tv) + log( + f"ppm_sweep best_bpb:{best[0]:.8f} H={best[1]} L={best[2]} T={best[3]} " + f"default_bpb:{default_bpb:.8f}" + ) + if os.environ.get("PPM_SWEEP_APPLY", "0") == "1": + return best[0] + return default_bpb + + +def eval_val_ppm_sliding(h, device, val_data, model, batch_seqs=32): + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + model.eval() + seq_len = h.eval_seq_len + stride = h.eval_stride + context_size = seq_len - stride + total_tokens = val_data.val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) if ws + context_size < total_tokens] + total_windows = len(window_starts) + my_s = total_windows * h.rank // h.world_size + my_e = total_windows * (h.rank + 1) // h.world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + tga_local = [] + lpa_local = [] + bla_local = [] + fwd_fn = model.module.forward_logits if hasattr(model, 'module') else model.forward_logits + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens = [] + for i, ws in enumerate(batch_ws): + we = min(ws + seq_len, total_tokens) + wlen = we - ws + wlens.append(wlen) + chunk = val_data.val_tokens[ws:we + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type='cuda', dtype=torch.bfloat16): + logits = fwd_fn(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction='none', + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else context_size + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + if val_data.val_bytes is not None: + tb = val_data.val_bytes[ws + s + 1: ws + wlen + 1].to(device=device, dtype=torch.float64) + else: + prev = x_batch[i, s:wlen] + tb = val_data.base_bytes_lut[tgt].to(torch.float64) + tb += (val_data.has_leading_space_lut[tgt] & ~val_data.is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + tga_local.append(tgt.cpu().to(torch.int64)) + lpa_local.append((-scored_nll).cpu().to(torch.float64)) + bla_local.append(tb.cpu().to(torch.int32)) + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss, val_bpb = _loss_bpb(loss_sum, token_count, byte_count) + if h.ppm_mixer_enabled: + tga_local_cat = torch.cat(tga_local) if tga_local else torch.zeros(0, dtype=torch.int64) + lpa_local_cat = torch.cat(lpa_local) if lpa_local else torch.zeros(0, dtype=torch.float64) + bla_local_cat = torch.cat(bla_local) if bla_local else torch.zeros(0, dtype=torch.int32) + if dist.is_available() and dist.is_initialized(): + local_size = torch.tensor([tga_local_cat.numel()], dtype=torch.int64, device=device) + sizes = [torch.zeros(1, dtype=torch.int64, device=device) for _ in range(h.world_size)] + dist.all_gather(sizes, local_size) + sizes_list = [int(s.item()) for s in sizes] + max_size = max(sizes_list) if sizes_list else 0 + tga_pad = torch.zeros(max_size, dtype=torch.int64, device=device) + lpa_pad = torch.zeros(max_size, dtype=torch.float64, device=device) + bla_pad = torch.zeros(max_size, dtype=torch.int32, device=device) + tga_pad[:tga_local_cat.numel()] = tga_local_cat.to(device) + lpa_pad[:lpa_local_cat.numel()] = lpa_local_cat.to(device) + bla_pad[:bla_local_cat.numel()] = bla_local_cat.to(device) + if h.rank == 0: + gather_t = [torch.zeros(max_size, dtype=torch.int64, device=device) for _ in range(h.world_size)] + gather_l = [torch.zeros(max_size, dtype=torch.float64, device=device) for _ in range(h.world_size)] + gather_b = [torch.zeros(max_size, dtype=torch.int32, device=device) for _ in range(h.world_size)] + else: + gather_t = None + gather_l = None + gather_b = None + dist.gather(tga_pad, gather_t, dst=0) + dist.gather(lpa_pad, gather_l, dst=0) + dist.gather(bla_pad, gather_b, dst=0) + if h.rank == 0: + tga_full = torch.cat([gather_t[r][:sizes_list[r]] for r in range(h.world_size)]).cpu().numpy() + lpa_full = torch.cat([gather_l[r][:sizes_list[r]] for r in range(h.world_size)]).cpu().numpy() + bla_full = torch.cat([gather_b[r][:sizes_list[r]] for r in range(h.world_size)]).cpu().numpy() + if getattr(h, "ppm_dump_inputs", False): + dump_path = os.path.join(h.artifact_dir or ".", f"{h.run_id}.ppm_inputs.npz") + np.savez_compressed( + dump_path, + target_ids=tga_full.astype(np.int64), + logp=lpa_full.astype(np.float64), + byte_lens=bla_full.astype(np.int32), + ) + log(f"ppm_dump_inputs:{dump_path}") + t0 = time.perf_counter() + mixer_bpb = _ppm_mixture_bpb(tga_full, lpa_full, val_data.sp, O=h.ppm_order, H=h.ppm_h, L_=h.ppm_l, T=h.ppm_t, token_byte_lens_np=bla_full) + log(f'ppm_mixer val_bpb:{mixer_bpb:.8f} eval_time:{1000.0*(time.perf_counter()-t0):.0f}ms order={h.ppm_order} H={h.ppm_h} L={h.ppm_l} T={h.ppm_t} N_tokens={lpa_full.size} N_sidecar_bytes={int(bla_full.sum())}') + val_bpb = mixer_bpb + else: + tga_np = tga_local_cat.numpy() + lpa_np = lpa_local_cat.numpy() + bla_np = bla_local_cat.numpy() + if getattr(h, "ppm_dump_inputs", False): + dump_path = os.path.join(h.artifact_dir or ".", f"{h.run_id}.ppm_inputs.npz") + np.savez_compressed( + dump_path, + target_ids=tga_np.astype(np.int64), + logp=lpa_np.astype(np.float64), + byte_lens=bla_np.astype(np.int32), + ) + log(f"ppm_dump_inputs:{dump_path}") + t0 = time.perf_counter() + mixer_bpb = _ppm_mixture_bpb(tga_np, lpa_np, val_data.sp, O=h.ppm_order, H=h.ppm_h, L_=h.ppm_l, T=h.ppm_t, token_byte_lens_np=bla_np) + log(f'ppm_mixer val_bpb:{mixer_bpb:.8f} eval_time:{1000.0*(time.perf_counter()-t0):.0f}ms order={h.ppm_order} H={h.ppm_h} L={h.ppm_l} T={h.ppm_t} N_tokens={lpa_np.size} N_sidecar_bytes={int(bla_np.sum())}') + val_bpb = mixer_bpb + model.train() + return val_loss, val_bpb + + +def eval_val(h, device, val_data, model, forward_logits_fn=None): + seq_len = h.eval_seq_len + local_batch_tokens = h.val_batch_tokens // (h.world_size * h.grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + f"VAL_BATCH_SIZE must provide at least one sequence per rank; got VAL_BATCH_SIZE={h.val_batch_tokens}, WORLD_SIZE={h.world_size}, GRAD_ACCUM_STEPS={h.grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_data.val_tokens.numel() - 1) // seq_len + seq_start = total_seqs * h.rank // h.world_size + seq_end = total_seqs * (h.rank + 1) // h.world_size + + # TODO: Don't truncate this. + seq_end = seq_start + ((seq_end - seq_start) // local_batch_seqs) * local_batch_seqs + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + run_forward_logits = ( + (model.module.forward_logits if hasattr(model, "module") else model.forward_logits) + if forward_logits_fn is None + else forward_logits_fn + ) + model.eval() + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + with torch.no_grad(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_data.val_tokens[raw_start:raw_end].to( + device=device, dtype=torch.int64, non_blocking=True + ) + x = local[:-1] + y = local[1:] + bos_pos = (x == BOS_ID).nonzero(as_tuple=True)[0].tolist() + cu_seqlens, max_seqlen = _build_cu_seqlens( + bos_pos, x.numel(), x.device, h.eval_seq_len, 64 + ) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits = run_forward_logits( + x[None], cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ).detach() + per_token_loss = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y.reshape(-1), + reduction="none", + ) + val_loss_sum += per_token_loss.to(torch.float64).sum() + val_token_count += float(y.numel()) + prev_ids = x + tgt_ids = y + sidecar_slice = val_data.val_bytes[raw_start + 1 : raw_end].to( + device=device, dtype=torch.int32, non_blocking=True + ) + val_byte_count += sidecar_slice.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + model.train() + return _loss_bpb(val_loss_sum, val_token_count, val_byte_count) + + +def jepa_weight_at(h, step, frac): + if not h.jepa_lite_enabled: + return 0.0 + freq = max(1, h.jepa_lite_frequency) + if step % freq != 0 or frac < h.jepa_lite_start_frac: + return 0.0 + weight = h.jepa_lite_weight + if frac > h.jepa_lite_end_frac: + weight *= max( + 0.0, + (1.0 - frac) / max(1e-09, 1.0 - h.jepa_lite_end_frac), + ) + return weight + + +def _find_docs(all_tokens): + bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].numpy() + docs = [] + for i in range(len(bos_positions)): + start = int(bos_positions[i]) + end = ( + int(bos_positions[i + 1]) + if i + 1 < len(bos_positions) + else all_tokens.numel() + ) + if i + 1 < len(bos_positions): + end += 1 + assert end - start >= 2 + docs.append((start, end - start)) + return docs + + +def _build_ttt_global_batches(doc_entries, h, ascending=False): + batch_size = h.ttt_batch_size + global_doc_entries = sorted(doc_entries, key=lambda x: x[1][1]) + global_batches = [ + global_doc_entries[i : i + batch_size] + for i in range(0, len(global_doc_entries), batch_size) + ] + indexed = list(enumerate(global_batches)) + if not ascending: + indexed.sort(key=lambda ib: -max(dl for _, (_, dl) in ib[1])) + return indexed + + +def _init_batch_counter(path): + with open(path, "wb") as f: + f.write((0).to_bytes(4, "little")) + + +def _claim_next_batch(counter_path, queue_len): + try: + with open(counter_path, "r+b") as f: + fcntl.flock(f, fcntl.LOCK_EX) + idx = int.from_bytes(f.read(4), "little") + f.seek(0) + f.write((idx + 1).to_bytes(4, "little")) + f.flush() + except FileNotFoundError: + return queue_len + return idx + + +def _compute_chunk_window(ci, pred_len, num_chunks, chunk_size, eval_seq_len): + chunk_end = pred_len if ci == num_chunks - 1 else (ci + 1) * chunk_size + win_start = max(0, chunk_end - eval_seq_len) + win_len = chunk_end - win_start + chunk_start = ci * chunk_size + chunk_offset = chunk_start - win_start + chunk_len = chunk_end - chunk_start + return win_start, win_len, chunk_offset, chunk_len + + +def _accumulate_bpb( + ptl, + x, + y, + chunk_offsets, + chunk_lens, + pos_idx, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + loss_sum, + byte_sum, + token_count, + y_bytes=None, +): + pos = pos_idx[: x.size(1)].unsqueeze(0) + mask = ( + (chunk_lens.unsqueeze(1) > 0) + & (pos >= chunk_offsets.unsqueeze(1)) + & (pos < (chunk_offsets + chunk_lens).unsqueeze(1)) + ) + mask_f64 = mask.to(torch.float64) + if y_bytes is not None: + tok_bytes = y_bytes.to(torch.float64) + else: + tok_bytes = base_bytes_lut[y].to(torch.float64) + tok_bytes += (has_leading_space_lut[y] & ~is_boundary_token_lut[x]).to( + torch.float64 + ) + loss_sum += (ptl.to(torch.float64) * mask_f64).sum() + byte_sum += (tok_bytes * mask_f64).sum() + token_count += chunk_lens.to(torch.float64).sum() + + +def _loss_bpb_from_sums(loss_sum, token_count, byte_sum): + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_sum.item()) + return val_loss, val_bpb + + +def _add_to_counter(path, delta): + try: + with open(path, "r+b") as f: + fcntl.flock(f, fcntl.LOCK_EX) + cur = int.from_bytes(f.read(8), "little", signed=True) + cur += int(delta) + f.seek(0) + f.write(int(cur).to_bytes(8, "little", signed=True)) + f.flush() + return cur + except FileNotFoundError: + return int(delta) + + +def _init_int64_counter(path): + with open(path, "wb") as f: + f.write((0).to_bytes(8, "little", signed=True)) + + +def _select_ttt_doc_entries(docs, h): + doc_entries = list(enumerate(docs)) + if h.val_doc_fraction < 1.0: + sample_n = max(1, int(round(len(docs) * h.val_doc_fraction))) + if os.environ.get("VAL_DOC_PREFIX_ONLY", "0") == "1": + return doc_entries[:sample_n] + sampled_indices = sorted( + random.Random(h.seed).sample(range(len(docs)), sample_n) + ) + return [(i, docs[i]) for i in sampled_indices] + return doc_entries + + +def train_val_ttt_global_sgd_distributed(h, device, val_data, base_model, val_tokens, batch_seqs=None): + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + base_model.eval() + seq_len = h.eval_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = h.global_ttt_chunk_tokens + batch_seqs = h.global_ttt_batch_seqs if batch_seqs is None else batch_seqs + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + ttt_params = [p for p in base_model.parameters()] + for p in ttt_params: + p.requires_grad_(True) + optimizer = torch.optim.SGD( + ttt_params, lr=h.global_ttt_lr, momentum=h.global_ttt_momentum + ) + t_start = time.perf_counter() + for ci in range(num_chunks): + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + is_last_chunk = ci == num_chunks - 1 + if is_last_chunk or h.global_ttt_epochs <= 0: + continue + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs <= 0: + continue + warmup_chunks = max(0, min(h.global_ttt_warmup_chunks, num_chunks - 1)) + if warmup_chunks > 0 and ci < warmup_chunks: + warmup_denom = max(warmup_chunks - 1, 1) + warmup_t = ci / warmup_denom + lr_now = ( + h.global_ttt_warmup_start_lr + + (h.global_ttt_lr - h.global_ttt_warmup_start_lr) * warmup_t + ) + else: + decay_steps = max(num_chunks - 1 - warmup_chunks, 1) + decay_ci = max(ci - warmup_chunks, 0) + lr_now = h.global_ttt_lr * 0.5 * ( + 1.0 + math.cos(math.pi * decay_ci / decay_steps) + ) + for pg in optimizer.param_groups: + pg["lr"] = lr_now + my_seq_s = chunk_seqs * h.rank // h.world_size + my_seq_e = chunk_seqs * (h.rank + 1) // h.world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ in range(h.global_ttt_epochs): + for bs in range(0, my_chunk_seqs, batch_seqs): + be = min(bs + batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x_flat = local[:-1] + y_flat = local[1:] + optimizer.zero_grad(set_to_none=True) + with torch.enable_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + if h.global_ttt_respect_doc_boundaries: + bos_pos = (x_flat == BOS_ID).nonzero(as_tuple=True)[0].tolist() + cu_seqlens, max_seqlen = _build_cu_seqlens( + bos_pos, x_flat.numel(), x_flat.device, h.eval_seq_len, 64 + ) + loss = base_model( + x_flat[None], + y_flat[None], + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + else: + x = x_flat.reshape(-1, seq_len) + y = y_flat.reshape(-1, seq_len) + loss = base_model(x, y) + loss.backward() + if dist.is_available() and dist.is_initialized(): + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.SUM) + p.grad.mul_(1.0 / h.world_size) + if h.global_ttt_grad_clip > 0: + torch.nn.utils.clip_grad_norm_(ttt_params, h.global_ttt_grad_clip) + optimizer.step() + base_model.eval() + if h.rank == 0: + elapsed = time.perf_counter() - t_start + log( + f"tttg: c{ci+1}/{num_chunks} lr:{lr_now:.6f} t:{elapsed:.1f}s" + ) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + +def _compute_ngram_hints_for_val(h, val_data, log0=print): + if not getattr(h, "ngram_tilt_enabled", False): + return None + from online_ngram_tilt import build_hints_for_targets + + all_tokens = val_data.val_tokens + targets_np_all = all_tokens.cpu().numpy().astype("uint16", copy=False)[1:] + max_targets = int(os.environ.get("NGRAM_HINT_MAX_TARGETS", "0")) + target_count = targets_np_all.shape[0] + if max_targets > 0: + targets_np = targets_np_all[: min(max_targets, target_count)] + else: + targets_np = targets_np_all + t_h0 = time.perf_counter() + hints_pkg = build_hints_for_targets( + target_token_ids_np=targets_np, + tokenizer_path=h.tokenizer_path, + vocab_size=h.vocab_size, + log0=log0, + token_order=h.token_order, + token_threshold=h.token_threshold, + token_boost=h.token_boost, + within_tau=h.within_tau, + within_boost=h.within_boost, + word_order=h.word_order, + word_normalize=h.word_normalize, + word_tau=h.word_tau, + word_boost=h.word_boost, + agree_add_boost=h.agree_add_boost, + ) + hint_global = torch.from_numpy(hints_pkg["hint_ids"].astype("int64")) + gate_global = torch.from_numpy(hints_pkg["gate_mask"]) + boost_global = torch.from_numpy(hints_pkg["boost"].astype("float32")) + if hint_global.numel() < target_count: + padded_hint = torch.zeros(target_count, dtype=torch.int64) + padded_gate = torch.zeros(target_count, dtype=torch.bool) + padded_boost = torch.zeros(target_count, dtype=torch.float32) + padded_hint[: hint_global.numel()] = hint_global + padded_gate[: gate_global.numel()] = gate_global + padded_boost[: boost_global.numel()] = boost_global + hint_global, gate_global, boost_global = padded_hint, padded_gate, padded_boost + log0( + f"ngram_tilt:precompute_done elapsed={time.perf_counter()-t_h0:.2f}s " + f"total_targets={hint_global.numel()} computed_targets={targets_np.shape[0]}" + ) + return hint_global, gate_global, boost_global + + +def eval_val_ttt_phased(h, base_model, device, val_data, forward_ttt_train, precomputed_hints=None): + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + base_model.eval() + for p in base_model.parameters(): + p.requires_grad_(False) + all_tokens = val_data.val_tokens + all_tokens_idx = all_tokens.to(torch.int32) + ngram_hint_global = None + ngram_gate_global = None + ngram_boost_global = None + if precomputed_hints is not None: + ngram_hint_global, ngram_gate_global, ngram_boost_global = precomputed_hints + log( + "ngram_tilt:using_precomputed_hints " + f"total_targets={ngram_hint_global.numel()}" + ) + elif getattr(h, "ngram_tilt_enabled", False): + ngram_hint_global, ngram_gate_global, ngram_boost_global = _compute_ngram_hints_for_val( + h, val_data, log0=log + ) + docs = _find_docs(all_tokens) + doc_entries = _select_ttt_doc_entries(docs, h) + prefix_doc_limit = max(0, min(len(doc_entries), int(h.phased_ttt_prefix_docs))) + num_phases = max(1, int(h.phased_ttt_num_phases)) + phase_boundaries = [] + for pi in range(num_phases): + boundary = prefix_doc_limit * (pi + 1) // num_phases + phase_boundaries.append(boundary) + current_phase = 0 + current_phase_boundary = phase_boundaries[0] + log( + "ttt_phased:" + f" total_docs:{len(doc_entries)} prefix_docs:{prefix_doc_limit} " + f"suffix_docs:{len(doc_entries) - prefix_doc_limit}" + f" num_phases:{num_phases} boundaries:{phase_boundaries}" + ) + chunk_size, eval_seq_len = h.ttt_chunk_size, h.ttt_eval_seq_len + eval_batch_set = None + if h.ttt_eval_batches: + eval_batch_set = set(int(x) for x in h.ttt_eval_batches.split(",") if x.strip()) + use_ascending = eval_batch_set is not None + global_batches_sorted = _build_ttt_global_batches( + doc_entries, h, ascending=use_ascending + ) + queue_len = len(global_batches_sorted) + counter_path = f"/tmp/ttt_counter_{h.run_id}" + prefix_counter_path = f"/tmp/ttt_prefix_counter_{h.run_id}" + pause_flag_path = f"/tmp/ttt_pause_flag_{h.run_id}" + if h.rank == 0: + _init_batch_counter(counter_path) + _init_int64_counter(prefix_counter_path) + try: + os.remove(pause_flag_path) + except FileNotFoundError: + pass + if dist.is_available() and dist.is_initialized(): + path_list = [counter_path, prefix_counter_path, pause_flag_path] + dist.broadcast_object_list(path_list, src=0) + counter_path, prefix_counter_path, pause_flag_path = path_list + dist.barrier() + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + byte_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + t_start = time.perf_counter() + reusable_lora = BatchedTTTLoRA( + h.ttt_batch_size, base_model, h.ttt_lora_rank, + q_lora=h.ttt_q_lora, k_lora=h.ttt_k_lora, v_lora=h.ttt_v_lora, + mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + + def _build_opt(lora): + local_lr = h.ttt_lora_lr * h.ttt_local_lr_mult + if h.ttt_optimizer == "sgd": + return torch.optim.SGD( + lora.parameters(), lr=local_lr, + momentum=h.ttt_beta1, weight_decay=h.ttt_weight_decay, + ) + return torch.optim.AdamW( + lora.parameters(), lr=local_lr, + betas=(h.ttt_beta1, h.ttt_beta2), + eps=1e-10, weight_decay=h.ttt_weight_decay, fused=True, + ) + + reusable_opt = _build_opt(reusable_lora) + local_scored_docs = [] + global_ttt_done = prefix_doc_limit == 0 + try: + while True: + queue_idx = _claim_next_batch(counter_path, queue_len) + if queue_idx >= queue_len: + break + orig_batch_idx, batch_entries = global_batches_sorted[queue_idx] + batch = [doc for _, doc in batch_entries] + bsz = len(batch) + prev_loss = loss_sum.item() + prev_bytes = byte_sum.item() + prev_tokens = token_count.item() + if bsz == reusable_lora.bsz: + reusable_lora.reset() + for s in reusable_opt.state.values(): + for k, v in s.items(): + if isinstance(v, torch.Tensor): + v.zero_() + elif k == "step": + s[k] = 0 + cur_lora = reusable_lora + cur_opt = reusable_opt + else: + cur_lora = BatchedTTTLoRA( + bsz, base_model, h.ttt_lora_rank, + q_lora=h.ttt_q_lora, k_lora=h.ttt_k_lora, v_lora=h.ttt_v_lora, + mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + cur_opt = _build_opt(cur_lora) + pred_lens = [doc_len - 1 for _, doc_len in batch] + num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens] + max_nc = max(num_chunks) + num_chunks_t = torch.tensor(num_chunks, dtype=torch.int64, device=device) + for ci in range(max_nc): + active = [ci < nc for nc in num_chunks] + needs_train = any(ci < nc - 1 for nc in num_chunks) + tok_starts = torch.zeros(bsz, dtype=torch.int64) + tok_wls = torch.zeros(bsz, dtype=torch.int64) + chunk_offsets_cpu = torch.zeros(bsz, dtype=torch.int64) + chunk_lens_cpu = torch.zeros(bsz, dtype=torch.int64) + for b in range(bsz): + if not active[b]: + continue + doc_start, doc_len = batch[b] + win_start, win_len, chunk_offset, chunk_len = _compute_chunk_window( + ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len + ) + tok_starts[b] = doc_start + win_start + tok_wls[b] = win_len + chunk_offsets_cpu[b] = chunk_offset + chunk_lens_cpu[b] = chunk_len + _, context_size, chunk_offset, _ = _compute_chunk_window( + ci, (ci + 1) * chunk_size, ci + 1, chunk_size, eval_seq_len + ) + col_idx = torch.arange(context_size + 1) + idx = tok_starts.unsqueeze(1) + col_idx.unsqueeze(0) + idx.clamp_(max=all_tokens.numel() - 1) + gathered_gpu = all_tokens_idx[idx].to( + device=device, dtype=torch.int64, non_blocking=True + ) + valid = (col_idx[:context_size].unsqueeze(0) < tok_wls.unsqueeze(1)).to( + device, non_blocking=True + ) + chunk_offsets = chunk_offsets_cpu.to(device, non_blocking=True) + chunk_lens = chunk_lens_cpu.to(device, non_blocking=True) + x = torch.where(valid, gathered_gpu[:, :context_size], 0) + y = torch.where(valid, gathered_gpu[:, 1 : context_size + 1], 0) + ctx_pos = torch.arange(context_size, device=device, dtype=torch.int64) + hint_ids_gpu = None + gate_mask_gpu = None + boost_gpu = None + if ngram_hint_global is not None: + hint_idx_cpu = ( + tok_starts.unsqueeze(1) + col_idx[:context_size].unsqueeze(0) + ).clamp_(min=0, max=ngram_hint_global.numel() - 1) + hint_ids_gpu = ngram_hint_global[hint_idx_cpu].to( + device=device, dtype=torch.int64, non_blocking=True + ) + gate_mask_gpu = ngram_gate_global[hint_idx_cpu].to( + device=device, non_blocking=True + ) + boost_gpu = ngram_boost_global[hint_idx_cpu].to( + device=device, dtype=torch.float32, non_blocking=True + ) + hint_ids_gpu = torch.where(valid, hint_ids_gpu, torch.zeros_like(hint_ids_gpu)) + gate_mask_gpu = gate_mask_gpu & valid + log_q_hint = None + if hint_ids_gpu is not None: + with torch.no_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_tok_loss, log_q_hint = forward_ttt_train( + x, y, lora=cur_lora, hint_ids=hint_ids_gpu + ) + else: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_tok_loss = forward_ttt_train(x, y, lora=cur_lora) + # CaseOps sidecar-driven byte budget. Mirror the index pattern + # used to build y from all_tokens: y[b, j] corresponds to the + # token at global position tok_starts[b] + 1 + j (when valid). + y_bytes_arg = None + if val_data.caseops_enabled and val_data.val_bytes is not None: + y_idx = ( + tok_starts.unsqueeze(1) + + 1 + + col_idx[:context_size].unsqueeze(0) + ) + y_idx = y_idx.clamp_(max=val_data.val_bytes.numel() - 1) + y_bytes_arg = val_data.val_bytes[y_idx].to( + device=device, dtype=torch.int32, non_blocking=True + ) + # Mirror the `valid` masking used for y so out-of-range tokens + # contribute zero bytes (matches y=0 substitution above). + y_bytes_arg = torch.where( + valid, y_bytes_arg, torch.zeros_like(y_bytes_arg) + ) + if hint_ids_gpu is not None and log_q_hint is not None: + from online_ngram_tilt import apply_tilt_to_ptl_torch_fast + + scored_loss = apply_tilt_to_ptl_torch_fast( + ptl=per_tok_loss, + log_q_hint=log_q_hint, + target_ids=y, + hint_ids=hint_ids_gpu, + gate_mask=gate_mask_gpu, + boost=boost_gpu, + ) + else: + scored_loss = per_tok_loss + with torch.no_grad(): + _accumulate_bpb( + scored_loss, + x, + y, + chunk_offsets, + chunk_lens, + ctx_pos, + val_data.base_bytes_lut, + val_data.has_leading_space_lut, + val_data.is_boundary_token_lut, + loss_sum, + byte_sum, + token_count, + y_bytes=y_bytes_arg, + ) + if scored_loss is not per_tok_loss: + del scored_loss + if needs_train: + activate_chunk_mask = (num_chunks_t - 1 > ci).float() + train_x, train_y = x, y + train_chunk_offset = chunk_offset + train_window = int(getattr(h, "ttt_train_window_tokens", 0)) + if train_window > 0 and context_size > max(train_window, chunk_size): + train_window = max(train_window, chunk_size) + train_end = min(context_size, chunk_offset + chunk_size) + train_start = max(0, train_end - train_window) + train_x = x[:, train_start:train_end].contiguous() + train_y = y[:, train_start:train_end].contiguous() + train_chunk_offset = chunk_offset - train_start + for gi in range(h.ttt_grad_steps): + if hint_ids_gpu is not None or gi > 0: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + train_per_tok_loss = forward_ttt_train( + train_x, train_y, lora=cur_lora + ) + else: + train_per_tok_loss = per_tok_loss + per_doc = train_per_tok_loss[ + :, train_chunk_offset : train_chunk_offset + chunk_size + ].mean(dim=-1) + cur_opt.zero_grad(set_to_none=True) + (per_doc * activate_chunk_mask).sum().backward() + cur_opt.step() + if train_per_tok_loss is not per_tok_loss: + del train_per_tok_loss + del per_tok_loss + batch_num = orig_batch_idx + 1 + doc_lens = [dl for _, dl in batch] + should_report = batch_num in eval_batch_set if eval_batch_set is not None else True + if should_report: + cur_tokens = token_count.item() + cur_loss_val = loss_sum.item() + cur_bytes_val = byte_sum.item() + dt = cur_tokens - prev_tokens + db = cur_bytes_val - prev_bytes + if dt > 0 and db > 0: + b_loss = (cur_loss_val - prev_loss) / dt + b_bpb = b_loss / math.log(2.0) * (dt / db) + else: + b_loss = b_bpb = 0.0 + r_loss = cur_loss_val / max(cur_tokens, 1) + r_bpb = r_loss / math.log(2.0) * (cur_tokens / max(cur_bytes_val, 1)) + elapsed = time.perf_counter() - t_start + log( + f"ttp: b{batch_num}/{queue_len} bl:{b_loss:.4f} bb:{b_bpb:.4f} " + f"rl:{r_loss:.4f} rb:{r_bpb:.4f} dl:{min(doc_lens)}-{max(doc_lens)} " + f"gd:{int(global_ttt_done)}" + ) + if not global_ttt_done: + local_scored_docs.extend( + (orig_batch_idx, pos, doc_start, doc_len) + for pos, (doc_start, doc_len) in enumerate(batch) + ) + prefix_done = _add_to_counter(prefix_counter_path, len(batch_entries)) + if prefix_done >= current_phase_boundary: + try: + with open(pause_flag_path, "x"): + pass + except FileExistsError: + pass + should_pause = os.path.exists(pause_flag_path) + if should_pause: + if dist.is_available() and dist.is_initialized(): + dist.barrier() + gathered_scored_docs = [None] * h.world_size + if dist.is_available() and dist.is_initialized(): + dist.all_gather_object(gathered_scored_docs, local_scored_docs) + else: + gathered_scored_docs = [local_scored_docs] + scored_docs_for_global = [] + for rank_docs in gathered_scored_docs: + if rank_docs: + scored_docs_for_global.extend(rank_docs) + scored_docs_for_global.sort(key=lambda x: (x[0], x[1])) + scored_docs_for_global = scored_docs_for_global[:current_phase_boundary] + scored_token_chunks = [ + val_data.val_tokens[doc_start : doc_start + doc_len] + for _, _, doc_start, doc_len in scored_docs_for_global + ] + if scored_token_chunks: + global_ttt_tokens = torch.cat(scored_token_chunks) + else: + global_ttt_tokens = val_data.val_tokens[:0] + if h.rank == 0: + prefix_done = 0 + try: + with open(prefix_counter_path, "rb") as f: + prefix_done = int.from_bytes( + f.read(8), "little", signed=True + ) + except FileNotFoundError: + pass + log( + f"ttpp: phase:{current_phase + 1}/{num_phases} pd:{prefix_done} " + f"gd:{len(scored_docs_for_global)} " + f"t:{time.perf_counter() - t_start:.1f}s" + ) + train_val_ttt_global_sgd_distributed( + h, device, val_data, base_model, global_ttt_tokens + ) + for p in base_model.parameters(): + p.requires_grad_(False) + reusable_lora = BatchedTTTLoRA( + h.ttt_batch_size, base_model, h.ttt_lora_rank, + q_lora=h.ttt_q_lora, k_lora=h.ttt_k_lora, v_lora=h.ttt_v_lora, + mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + reusable_opt = _build_opt(reusable_lora) + current_phase += 1 + if current_phase >= num_phases: + global_ttt_done = True + else: + current_phase_boundary = phase_boundaries[current_phase] + if h.rank == 0: + try: + os.remove(pause_flag_path) + except FileNotFoundError: + pass + if dist.is_available() and dist.is_initialized(): + dist.barrier() + if h.rank == 0: + log(f"ttpr: phase:{current_phase}/{num_phases} t:{time.perf_counter() - t_start:.1f}s") + del cur_lora, cur_opt + finally: + pass + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.train() + return _loss_bpb_from_sums(loss_sum, token_count, byte_sum) + + +def timed_eval(label, fn, *args, **kwargs): + torch.cuda.synchronize() + t0 = time.perf_counter() + val_loss, val_bpb = fn(*args, **kwargs) + torch.cuda.synchronize() + elapsed_ms = 1e3 * (time.perf_counter() - t0) + log( + f"{label} val_loss:{val_loss:.8f} val_bpb:{val_bpb:.8f} eval_time:{elapsed_ms:.0f}ms" + ) + return val_loss, val_bpb + + +def train_model(h, device, val_data): + base_model = GPT(h).to(device).bfloat16() + restore_fp32_params(base_model) + compile_enabled = os.environ.get("DISABLE_COMPILE", "0") != "1" + if compile_enabled: + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + compiled_forward_logits = torch.compile( + base_model.forward_logits, dynamic=False, fullgraph=True + ) + else: + log("compile:disabled_by_env") + compiled_model = base_model + compiled_forward_logits = base_model.forward_logits + model = compiled_model + log(f"model_params:{sum(p.numel()for p in base_model.parameters())}") + if h.jepa_lite_enabled: + log( + "jepa_lite_config: " + f"weight={h.jepa_lite_weight} hidden={h.jepa_lite_hidden} " + f"horizon={h.jepa_lite_horizon} stride={h.jepa_lite_stride} " + f"frequency={h.jepa_lite_frequency} start_frac={h.jepa_lite_start_frac} " + f"end_frac={h.jepa_lite_end_frac} lr={h.jepa_lite_lr}" + ) + if h.bijepa_enabled: + log( + "bijepa_config: " + f"backward_weight={h.bijepa_backward_weight} " + "dual_predictor=true training_only=true" + ) + bijepax_module = None + bijepax_opt = None + forward_with_hidden = None + if h.bijepax_enabled: + bijepax_module = MultiDirectionalBiJEPAX( + d_model=h.model_dim, + fwd_hops=h.bijepax_fwd_hops, + bwd_hops=h.bijepax_bwd_hops, + cycle=h.bijepax_cycle, + head_dim=h.bijepax_head_dim, + ).to(device).bfloat16() + bijepax_opt = torch.optim.Adam( + bijepax_module.parameters(), lr=h.bijepax_lr, betas=(0.9, 0.99) + ) + forward_with_hidden = ( + torch.compile(base_model.forward_with_hidden, dynamic=False, fullgraph=True) + if compile_enabled + else base_model.forward_with_hidden + ) + log( + "bijepax_config: " + f"weight={h.bijepax_weight} start={h.bijepax_start_frac} " + f"end={h.bijepax_end_frac} fwd={h.bijepax_fwd_hops} " + f"bwd={h.bijepax_bwd_hops} cycle={h.bijepax_cycle} " + f"head_dim={h.bijepax_head_dim} stride={h.bijepax_stride} " + f"lr={h.bijepax_lr} training_only=true" + ) + optimizers = Optimizers(h, base_model) + train_loader = DocumentPackingLoader(h, device) + max_wallclock_ms = ( + 1e3 * h.max_wallclock_seconds if h.max_wallclock_seconds > 0 else None + ) + if max_wallclock_ms is not None: + max_wallclock_ms -= h.gptq_reserve_seconds * 1e3 + log( + f"gptq:reserving {h.gptq_reserve_seconds:.0f}s, effective={max_wallclock_ms:.0f}ms" + ) + + def training_frac(step, elapsed_ms): + if max_wallclock_ms is None: + return step / max(h.iterations, 1) + return elapsed_ms / max(max_wallclock_ms, 1e-09) + + def lr_mul(frac): + if h.warmdown_frac <= 0: + return 1.0 + if frac >= 1.0 - h.warmdown_frac: + return max((1.0 - frac) / h.warmdown_frac, h.min_lr) + return 1.0 + + _clip_params = [p for p in base_model.parameters() if p.requires_grad] + def step_fn(step, lr_scale, frac_now=0.0): + if getattr(base_model, "jepa_predictor", None) is not None: + base_model.jepa_weight_current.fill_(jepa_weight_at(h, step, frac_now)) + bijepax_weight = ( + bijepax_weight_at( + frac_now, + h.bijepax_start_frac, + h.bijepax_end_frac, + h.bijepax_weight, + ) + if bijepax_module is not None + else 0.0 + ) + train_loss = torch.zeros((), device=device) + for micro_step in range(h.grad_accum_steps): + x, y, cu_seqlens, _max_seqlen = train_loader.next_batch( + h.train_batch_tokens, h.grad_accum_steps + ) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + if bijepax_module is not None and bijepax_weight > 0.0: + ce_loss, hidden = forward_with_hidden( + x, y, cu_seqlens=cu_seqlens, max_seqlen=h.train_seq_len + ) + loss = ce_loss + bijepax_module( + hidden, bijepax_weight, h.bijepax_stride + ) + else: + loss = model( + x, y, cu_seqlens=cu_seqlens, max_seqlen=h.train_seq_len + ) + train_loss += loss.detach() + (loss / h.grad_accum_steps).backward() + train_loss /= h.grad_accum_steps + if step <= h.muon_momentum_warmup_steps: + + frac = ( + + min(step / h.muon_momentum_warmup_steps, 1.0) + + if h.muon_momentum_warmup_steps > 0 + + else 1.0 + + ) + + muon_momentum = ( + + 1 - frac + + ) * h.muon_momentum_warmup_start + frac * h.muon_momentum + + for group in optimizers.optimizer_muon.param_groups: + + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * lr_scale + if h.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(_clip_params, h.grad_clip_norm) + optimizers.step(distributed=h.distributed) + if bijepax_opt is not None and bijepax_weight > 0.0: + bijepax_opt.step() + bijepax_opt.zero_grad(set_to_none=True) + return train_loss + + if h.warmup_steps > 0: + initial_model_state = { + name: tensor.detach().cpu().clone() + for (name, tensor) in base_model.state_dict().items() + } + initial_optimizer_states = [ + copy.deepcopy(opt.state_dict()) for opt in optimizers + ] + model.train() + num_tokens_local = h.train_batch_tokens // h.world_size + for blk in base_model.blocks: + blk.attn.rotary(num_tokens_local, device, torch.bfloat16) + cu_bucket_size = train_loader.cu_bucket_size + warmup_cu_buckets = tuple(cu_bucket_size * i for i in range(1, 5)) + warmup_cu_iters = 3 + x, y, cu_seqlens, _ = train_loader.next_batch( + h.train_batch_tokens, h.grad_accum_steps + ) + log(f"warmup_cu_buckets:{','.join(str(b) for b in warmup_cu_buckets)} iters_each:{warmup_cu_iters}") + def _run_cu_bucket_warmup(): + for bucket_len in warmup_cu_buckets: + boundaries = list(range(0, x.size(1), max(h.train_seq_len, 1))) + if boundaries[-1] != x.size(1): + boundaries.append(x.size(1)) + cu = torch.full((bucket_len,), x.size(1), dtype=torch.int32, device=device) + cu[: len(boundaries)] = torch.tensor(boundaries, dtype=torch.int32, device=device) + for _ in range(warmup_cu_iters): + optimizers.zero_grad_all() + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + wloss = model(x, y, cu_seqlens=cu, max_seqlen=h.train_seq_len) + (wloss / h.grad_accum_steps).backward() + optimizers.zero_grad_all() + _run_cu_bucket_warmup() + if h.num_loops > 0: + base_model.looping_active = True + _run_cu_bucket_warmup() + base_model.looping_active = False + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step, 1.0) + if ( + warmup_step <= 5 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == h.warmup_steps + ): + log(f"warmup_step: {warmup_step+1}/{h.warmup_steps}") + if h.num_loops > 0: + base_model.looping_active = True + log( + f"loop_warmup:enabled encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}" + ) + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step, 1.0) + if ( + warmup_step <= 5 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == h.warmup_steps + ): + log(f"loop_warmup_step: {warmup_step+1}/{h.warmup_steps}") + base_model.looping_active = False + base_model.load_state_dict(initial_model_state, strict=True) + for (opt, state) in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + optimizers.zero_grad_all() + train_loader = DocumentPackingLoader(h, device) + _live_state = base_model.state_dict(keep_vars=True) + ema_state = { + name: t.detach().float().clone() + for (name, t) in _live_state.items() + if not name.startswith("jepa_predictor.") + and not name.startswith("jepa_bwd_predictor.") + and name != "jepa_weight_current" + } + _ema_pairs = [ + (ema_state[name], t) + for (name, t) in _live_state.items() + if name in ema_state + ] + ema_decay = h.ema_decay + training_time_ms = 0.0 + forced_stop_step = int(os.environ.get("FORCE_STOP_STEP", "0")) + stop_after_step = forced_stop_step if forced_stop_step > 0 else None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = ( + step == h.iterations + or stop_after_step is not None + and step >= stop_after_step + ) + should_validate = ( + last_step or h.val_loss_every > 0 and step % h.val_loss_every == 0 + ) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1e3 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + h, device, val_data, model, compiled_forward_logits + ) + log( + f"{step}/{h.iterations} val_loss: {val_loss:.4f} val_bpb: {val_bpb:.4f}" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < h.iterations: + log( + f"stopping_early: wallclock_cap train_time: {training_time_ms:.0f}ms step: {step}/{h.iterations}" + ) + break + elapsed_ms = training_time_ms + 1e3 * (time.perf_counter() - t0) + frac = training_frac(step, elapsed_ms) + scale = lr_mul(frac) + if ( + h.num_loops > 0 + and not base_model.looping_active + and frac >= h.enable_looping_at + ): + base_model.looping_active = True + log( + f"layer_loop:enabled step:{step} frac:{frac:.3f} encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}" + ) + train_loss = step_fn(step, scale, frac) + with torch.no_grad(): + for ema_t, t in _ema_pairs: + ema_t.mul_(ema_decay).add_(t.detach(), alpha=1.0 - ema_decay) + step += 1 + approx_training_time_ms = training_time_ms + 1e3 * (time.perf_counter() - t0) + should_log_train = h.train_log_every > 0 and ( + step <= 5 or step % h.train_log_every == 0 or stop_after_step is not None + ) + if should_log_train: + tok_per_sec = step * h.train_batch_tokens / (approx_training_time_ms / 1e3) + log( + f"{step}/{h.iterations} train_loss: {train_loss.item():.4f} train_time: {approx_training_time_ms/60000:.1f}m tok/s: {tok_per_sec:.0f}" + ) + reached_cap = ( + forced_stop_step <= 0 + and max_wallclock_ms is not None + and approx_training_time_ms >= max_wallclock_ms + ) + if h.distributed and forced_stop_step <= 0 and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log( + f"peak memory allocated: {torch.cuda.max_memory_allocated()//1024//1024} MiB reserved: {torch.cuda.max_memory_reserved()//1024//1024} MiB" + ) + log("ema:applying EMA weights") + if getattr(base_model, "jepa_predictor", None) is not None: + base_model.jepa_predictor = None + log("jepa_lite_predictor_removed_before_serialize: true") + if getattr(base_model, "jepa_bwd_predictor", None) is not None: + base_model.jepa_bwd_predictor = None + log("bijepa_backward_predictor_removed_before_serialize: true") + current_state = base_model.state_dict() + avg_state = { + name: t.to(dtype=current_state[name].dtype) for (name, t) in ema_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + return base_model, compiled_model, compiled_forward_logits + + +def train_and_eval(h, device): + global BOS_ID + random.seed(h.seed) + np.random.seed(h.seed) + torch.manual_seed(h.seed) + torch.cuda.manual_seed_all(h.seed) + if h.artifact_dir and h.is_main_process: + os.makedirs(h.artifact_dir, exist_ok=True) + val_data = ValidationData(h, device) + log( + f"train_shards: {len(list(Path(h.datasets_dir).resolve().glob('fineweb_train_*.bin')))}" + ) + log(f"val_tokens: {val_data.val_tokens.numel()-1}") + # TTT_EVAL_ONLY: skip training + GPTQ, jump straight to TTT eval on a + # pre-existing quantized artifact. Used to test TTT-only improvements + # (e.g., PR-1767's alpha/warm-start/WD) without retraining. + ttt_eval_only = os.environ.get("TTT_EVAL_ONLY", "0") == "1" + quantize_only = os.environ.get("QUANTIZE_ONLY", "0") == "1" + if ttt_eval_only: + log("TTT_EVAL_ONLY=1 — skipping training + GPTQ, loading saved artifact for TTT eval") + log(f"ttt_lora_alpha: {BatchedLinearLoRA._ALPHA}") + log(f"ttt_warm_start_a: {BatchedLinearLoRA._WARM_START_A}") + log(f"ttt_weight_decay: {h.ttt_weight_decay}") + if h.jepa_lite_enabled: + h.jepa_lite_enabled = False + log("jepa_lite_disabled_for_eval_only: true") + elif quantize_only: + log("QUANTIZE_ONLY=1 — skipping training, loading saved full-precision checkpoint") + log(f"quantize_only checkpoint: {h.model_path}") + if h.jepa_lite_enabled: + h.jepa_lite_enabled = False + log("jepa_lite_disabled_for_quantize_only: true") + if BOS_ID is None: + BOS_ID = 1 + base_model = GPT(h).to(device).bfloat16() + state = torch.load(h.model_path, map_location="cpu") + base_model.load_state_dict(state, strict=True) + del state + serialize(h, base_model, Path(__file__).read_text(encoding="utf-8")) + if h.distributed: + dist.barrier() + else: + base_model, compiled_model, compiled_forward_logits = train_model( + h, device, val_data + ) + if h.jepa_lite_enabled: + h.jepa_lite_enabled = False + log("jepa_lite_disabled_after_training: true") + torch._dynamo.reset() + timed_eval( + "diagnostic pre-quantization post-ema", + eval_val, + h, + device, + val_data, + compiled_model, + compiled_forward_logits, + ) + if os.environ.get("PREQUANT_ONLY", "0") == "1": + log("PREQUANT_ONLY=1 — skipping serialize/GPTQ/post-quant eval/TTT") + return + serialize(h, base_model, Path(__file__).read_text(encoding="utf-8")) + if h.distributed: + dist.barrier() + eval_model = deserialize(h, device) + if h.num_loops > 0: + eval_model.looping_active = True + if not ttt_eval_only: + if os.environ.get("DISABLE_COMPILE", "0") == "1": + log("eval_compile:disabled_by_env") + compiled_model = eval_model + compiled_forward_logits = eval_model.forward_logits + else: + compiled_model = torch.compile(eval_model, dynamic=False, fullgraph=True) + compiled_forward_logits = torch.compile( + eval_model.forward_logits, dynamic=False, fullgraph=True + ) + timed_eval( + "diagnostic quantized", + eval_val, + h, + device, + val_data, + compiled_model, + compiled_forward_logits, + ) + if h.ttt_enabled or not h.ppm_mixer_enabled: + del eval_model + if h.ttt_enabled: + if not ttt_eval_only: + del compiled_model + if ttt_eval_only: + del eval_model + torch._dynamo.reset() + torch.cuda.empty_cache() + ttt_model = deserialize(h, device) + if h.num_loops > 0: + ttt_model.looping_active = True + for p in ttt_model.parameters(): + p.requires_grad_(False) + + if h.rope_yarn: + _yarn_seqlen = h.train_batch_tokens // h.grad_accum_steps + for block in ttt_model.blocks: + block.attn.rotary(_yarn_seqlen, device, torch.bfloat16) + else: + for block in ttt_model.blocks: + block.attn.rotary._cos_cached = None + block.attn.rotary._sin_cached = None + block.attn.rotary._seq_len_cached = 0 + block.attn.rotary(h.ttt_eval_seq_len, device, torch.bfloat16) + + def _fwd_ttt_inner(input_ids, target_ids, lora): + return ttt_model.forward_ttt(input_ids, target_ids, lora=lora) + + def _fwd_ttt_hint_inner(input_ids, target_ids, lora, hint_ids): + return ttt_model.forward_ttt( + input_ids, target_ids, lora=lora, hint_ids=hint_ids + ) + + _fwd_ttt_compiled_inner = None + _fwd_ttt_hint_compiled_inner = None + + def _fwd_ttt(input_ids, target_ids, lora, hint_ids=None): + nonlocal _fwd_ttt_compiled_inner, _fwd_ttt_hint_compiled_inner + if os.environ.get("DISABLE_COMPILE", "0") == "1": + if hint_ids is not None: + return _fwd_ttt_hint_inner( + input_ids, target_ids, lora=lora, hint_ids=hint_ids + ) + return _fwd_ttt_inner(input_ids, target_ids, lora=lora) + if hint_ids is not None: + if _fwd_ttt_hint_compiled_inner is None: + _fwd_ttt_hint_compiled_inner = torch.compile( + _fwd_ttt_hint_inner, dynamic=True + ) + return _fwd_ttt_hint_compiled_inner( + input_ids, target_ids, lora=lora, hint_ids=hint_ids + ) + if _fwd_ttt_compiled_inner is None: + _fwd_ttt_compiled_inner = torch.compile(_fwd_ttt_inner, dynamic=True) + return _fwd_ttt_compiled_inner(input_ids, target_ids, lora=lora) + + fwd_ttt_compiled = _fwd_ttt + log(f"ttt_lora:warming up compile (random tokens, no val data)") + if BOS_ID is None: + BOS_ID = 1 + t_warmup = time.perf_counter() + warmup_bszes = [h.ttt_batch_size] + for bsz in warmup_bszes: + wl = BatchedTTTLoRA( + bsz, ttt_model, h.ttt_lora_rank, + q_lora=h.ttt_q_lora, k_lora=h.ttt_k_lora, v_lora=h.ttt_v_lora, + mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + wo = torch.optim.AdamW( + wl.parameters(), + lr=h.ttt_lora_lr * h.ttt_local_lr_mult, + betas=(h.ttt_beta1, h.ttt_beta2), + eps=1e-10, + weight_decay=h.ttt_weight_decay, + fused=True, + ) + train_warmup_lens = [h.ttt_chunk_size] + train_window = int(getattr(h, "ttt_train_window_tokens", 0)) + if train_window > h.ttt_chunk_size: + train_warmup_lens.append(train_window) + for ctx_len in train_warmup_lens: + xw = torch.randint(0, h.vocab_size, (bsz, ctx_len), device=device, dtype=torch.int64) + yw = torch.randint(0, h.vocab_size, (bsz, ctx_len), device=device, dtype=torch.int64) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = fwd_ttt_compiled(xw, yw, lora=wl) + ptl[:, : min(h.ttt_chunk_size, ctx_len)].mean(dim=-1).sum().backward() + wo.step() + wo.zero_grad(set_to_none=True) + if h.ngram_tilt_enabled: + ctx_len = h.ttt_eval_seq_len + xw = torch.randint(0, h.vocab_size, (bsz, ctx_len), device=device, dtype=torch.int64) + yw = torch.randint(0, h.vocab_size, (bsz, ctx_len), device=device, dtype=torch.int64) + hintw = torch.randint( + 0, h.vocab_size, (bsz, ctx_len), device=device, dtype=torch.int64 + ) + with torch.no_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + fwd_ttt_compiled(xw, yw, lora=wl, hint_ids=hintw) + del wl, wo + torch.cuda.empty_cache() + compile_elapsed = time.perf_counter() - t_warmup + log(f"ttt_lora:compile warmup done ({compile_elapsed:.1f}s)") + precomputed_hints = None + if h.ngram_tilt_enabled and h.ngram_hint_precompute_outside: + log("ngram_tilt:precomputing hints before TTT eval timer") + precomputed_hints = _compute_ngram_hints_for_val(h, val_data, log0=log) + log("\nbeginning TTT eval timer") + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_val_loss, ttt_val_bpb = eval_val_ttt_phased( + h, + ttt_model, + device, + val_data, + forward_ttt_train=fwd_ttt_compiled, + precomputed_hints=precomputed_hints, + ) + torch.cuda.synchronize() + ttt_eval_elapsed = time.perf_counter() - t_ttt + log( + "quantized_ttt_phased " + f"val_loss:{ttt_val_loss:.8f} val_bpb:{ttt_val_bpb:.8f} " + f"eval_time:{1e3*ttt_eval_elapsed:.0f}ms" + ) + log(f"total_eval_time:{ttt_eval_elapsed:.1f}s") + if h.ppm_mixer_enabled: + import sys as _sys + log("beginning PPM sliding eval") + _sys.stdout.flush() + torch.cuda.synchronize() + if dist.is_available() and dist.is_initialized(): + dist.barrier() + t_ppm = time.perf_counter() + try: + ppm_val_loss, ppm_val_bpb = eval_val_ppm_sliding( + h, device, val_data, ttt_model, batch_seqs=16 + ) + torch.cuda.synchronize() + ppm_elapsed = time.perf_counter() - t_ppm + log( + f"ppm_sliding val_loss:{ppm_val_loss:.8f} val_bpb:{ppm_val_bpb:.8f} " + f"eval_time:{1e3*ppm_elapsed:.0f}ms" + ) + except Exception as _e: + log(f"PPM eval error: {_e}") + import traceback as _tb + log(_tb.format_exc()) + _sys.stdout.flush() + del ttt_model + elif h.ppm_mixer_enabled: + import sys as _sys + log("beginning PPM sliding eval") + _sys.stdout.flush() + torch.cuda.synchronize() + if dist.is_available() and dist.is_initialized(): + dist.barrier() + t_ppm = time.perf_counter() + try: + ppm_val_loss, ppm_val_bpb = eval_val_ppm_sliding( + h, device, val_data, eval_model, batch_seqs=16 + ) + torch.cuda.synchronize() + ppm_elapsed = time.perf_counter() - t_ppm + log( + f"ppm_sliding val_loss:{ppm_val_loss:.8f} val_bpb:{ppm_val_bpb:.8f} " + f"eval_time:{1e3*ppm_elapsed:.0f}ms" + ) + except Exception as _e: + log(f"PPM eval error: {_e}") + import traceback as _tb + log(_tb.format_exc()) + _sys.stdout.flush() + del eval_model + + +def main(): + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError( + f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral" + ) + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.set_float32_matmul_precision("high") + from torch.backends.cuda import ( + enable_cudnn_sdp, + enable_flash_sdp, + enable_math_sdp, + enable_mem_efficient_sdp, + ) + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + torch._dynamo.config.optimize_ddp = False + torch._dynamo.config.cache_size_limit = 64 + h = Hyperparameters() + set_logging_hparams(h) + if h.is_main_process: + os.makedirs(h.artifact_dir if h.artifact_dir else "logs", exist_ok=True) + log(100 * "=", console=False) + log("Hyperparameters:", console=True) + for (k, v) in sorted(vars(type(h)).items()): + if not k.startswith("_"): + log(f" {k}: {v}", console=True) + log("=" * 100, console=False) + log("Source code:", console=False) + log("=" * 100, console=False) + with open(__file__, "r", encoding="utf-8") as _src: + log(_src.read(), console=False) + log("=" * 100, console=False) + log(f"Running Python {sys.version}", console=False) + log(f"Running PyTorch {torch.__version__}", console=False) + log("=" * 100, console=False) + train_and_eval(h, device) + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/train_seed314.log b/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/train_seed314.log new file mode 100644 index 0000000000..efd801db03 --- /dev/null +++ b/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/train_seed314.log @@ -0,0 +1,244 @@ +W0501 02:52:10.323000 979375 torch/distributed/run.py:803] +W0501 02:52:10.323000 979375 torch/distributed/run.py:803] ***************************************** +W0501 02:52:10.323000 979375 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0501 02:52:10.323000 979375 torch/distributed/run.py:803] ***************************************** +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + agree_add_boost: 0.5 + artifact_dir: /workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s314_20260501_025209 + attn_clip_sigmas: 13.0 + attn_out_gate_enabled: False + attn_out_gate_src: proj + awq_lite_bits: 8 + awq_lite_enabled: True + awq_lite_group_size: 64 + awq_lite_group_top_k: 1 + beta1: 0.9 + beta2: 0.99 + bijepa_backward_weight: 1.0 + bijepa_enabled: False + bijepax_bwd_hops: (4,) + bijepax_cycle: False + bijepax_enabled: True + bijepax_end_frac: 0.8 + bijepax_fwd_hops: (4,) + bijepax_head_dim: 32 + bijepax_lr: 0.001 + bijepax_start_frac: 0.35 + bijepax_stride: 64 + bijepax_weight: 0.01 + caseops_enabled: True + compressor: pergroup + data_dir: ./data/ + datasets_dir: /workspace/parameter-golf/data/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved + distributed: True + ema_decay: 0.9965 + embed_bits: 7 + embed_clip_sigmas: 14.0 + embed_lr: 0.6 + embed_wd: 0.085 + enable_looping_at: 0.35 + eval_seq_len: 2048 + eval_stride: 512 + fused_ce_enabled: True + gate_window: 12 + gated_attn_enabled: False + gated_attn_init_std: 0.01 + gated_attn_quant_gate: True + global_ttt_batch_seqs: 32 + global_ttt_chunk_tokens: 32768 + global_ttt_epochs: 1 + global_ttt_grad_clip: 1.0 + global_ttt_lr: 0.001 + global_ttt_momentum: 0.9 + global_ttt_respect_doc_boundaries: True + global_ttt_warmup_chunks: 0 + global_ttt_warmup_start_lr: 0.0 + gptq_calibration_batches: 16 + gptq_reserve_seconds: 0.5 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + is_main_process: True + iterations: 20000 + jepa_lite_enabled: False + jepa_lite_end_frac: 0.85 + jepa_lite_frequency: 8 + jepa_lite_hidden: 128 + jepa_lite_horizon: 4 + jepa_lite_lr: 0.01 + jepa_lite_start_frac: 0.35 + jepa_lite_stride: 16 + jepa_lite_weight: 0.03 + ln_scale: True + local_rank: 0 + logfile: /workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s314_20260501_025209/v15_bijepaxlite_lqer1_nocompile_s314_20260501_025209.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + lqer_asym_enabled: True + lqer_asym_group: 64 + lqer_enabled: True + lqer_factor_bits: 4 + lqer_gain_select: False + lqer_rank: 4 + lqer_scope: all + lqer_top_k: 1 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.026 + max_wallclock_seconds: 600.0 + min_lr: 0.1 + mlp_clip_sigmas: 11.5 + mlp_mult: 4.0 + model_dim: 512 + model_path: /workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s314_20260501_025209/final_model.pt + muon_backend_steps: 5 + muon_momentum: 0.97 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + ngram_hint_precompute_outside: True + ngram_tilt_enabled: True + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 2 + parallel_final_lane: mean + parallel_start_layer: 8 + phased_ttt_num_phases: 3 + phased_ttt_prefix_docs: 2500 + ppm_dump_inputs: False + ppm_h: 0.999 + ppm_l: 0.18 + ppm_mixer_enabled: True + ppm_order: 5 + ppm_t: 0.8 + qk_gain_init: 5.25 + quantized_model_path: /workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s314_20260501_025209/final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + rope_yarn: False + run_id: v15_bijepaxlite_lqer1_nocompile_s314_20260501_025209 + scalar_lr: 0.02 + seed: 314 + skip_gates_enabled: True + smear_gate_enabled: True + sparse_attn_gate_enabled: True + sparse_attn_gate_init_std: 0.0 + sparse_attn_gate_scale: 0.5 + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + token_boost: 2.625 + token_order: 16 + token_threshold: 0.8 + tokenizer_path: /workspace/parameter-golf/data/tokenizers/fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model + train_batch_tokens: 786432 + train_files: /workspace/parameter-golf/data/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_train_*.bin + train_log_every: 250 + train_seq_len: 2048 + ttt_batch_size: 64 + ttt_beta1: 0.0 + ttt_beta2: 0.99 + ttt_chunk_size: 48 + ttt_enabled: False + ttt_eval_batches: + ttt_eval_seq_len: 2048 + ttt_grad_steps: 1 + ttt_k_lora: True + ttt_local_lr_mult: 0.75 + ttt_lora_lr: 0.0001 + ttt_lora_rank: 80 + ttt_mask: no_qv + ttt_mlp_lora: True + ttt_o_lora: True + ttt_optimizer: adam + ttt_q_lora: False + ttt_train_window_tokens: 0 + ttt_v_lora: False + ttt_weight_decay: 0.5 + val_batch_tokens: 524288 + val_bytes_files: /workspace/parameter-golf/data/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_val_bytes_*.bin + val_doc_fraction: 1.0 + val_files: /workspace/parameter-golf/data/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.85 + warmup_steps: 20 + within_boost: 0.75 + within_tau: 0.45 + word_boost: 0.75 + word_normalize: strip_punct_lower + word_order: 4 + word_tau: 0.65 + world_size: 8 + xsa_last_n: 11 +train_shards: 80 +val_tokens: 47851520 +compile:disabled_by_env +model_params:35945673 +bijepax_config: weight=0.01 start=0.35 end=0.8 fwd=(4,) bwd=(4,) cycle=False head_dim=32 stride=64 lr=0.001 training_only=true +gptq:reserving 0s, effective=599500ms +warmup_cu_buckets:64,128,192,256 iters_each:3 +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup:enabled encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 +0/20000 val_loss: 8.9980 val_bpb: 4.1115 +1/20000 train_loss: 8.9988 train_time: 0.0m tok/s: 4316911 +2/20000 train_loss: 12.8490 train_time: 0.0m tok/s: 3773018 +3/20000 train_loss: 10.2358 train_time: 0.0m tok/s: 3620744 +4/20000 train_loss: 8.6660 train_time: 0.0m tok/s: 3552879 +5/20000 train_loss: 7.9135 train_time: 0.0m tok/s: 3514777 +250/20000 train_loss: 2.9125 train_time: 1.0m tok/s: 3364437 +500/20000 train_loss: 2.5540 train_time: 1.9m tok/s: 3366468 +750/20000 train_loss: 2.6987 train_time: 2.9m tok/s: 3367170 +layer_loop:enabled step:899 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +1000/20000 train_loss: 2.7476 train_time: 4.1m tok/s: 3205099 +1250/20000 train_loss: 2.5457 train_time: 5.5m tok/s: 2953029 +1500/20000 train_loss: 2.5150 train_time: 7.0m tok/s: 2805985 +1750/20000 train_loss: 2.5345 train_time: 8.5m tok/s: 2709881 +2000/20000 train_loss: 2.4815 train_time: 9.9m tok/s: 2642045 +2012/20000 val_loss: 2.4311 val_bpb: 1.1108 +stopping_early: wallclock_cap train_time: 599586ms step: 2012/20000 +peak memory allocated: 46676 MiB reserved: 52266 MiB +ema:applying EMA weights +diagnostic pre-quantization post-ema val_loss:2.42899363 val_bpb:1.10988323 eval_time:9910ms +Serialized model: 135418111 bytes +Code size (uncompressed): 213292 bytes +Code size (compressed): 41999 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 67 Hessians in 4.0s +Quantized weights: + gate_int8_row: blocks.attn.attn_gate_w + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int7)+awqgrpint8+lqer_asym: tok_emb.weight + passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, parallel_post_lambdas, parallel_resid_lambdas, skip_gates, skip_weights, smear_gate.weight, smear_lambda, softcap_neg, softcap_pos +Serialize: per-group lrzip compression... +Serialize: per-group compression done in 105.0s +Serialized model quantized+pergroup: 15957540 bytes +Total submission size quantized+pergroup: 15999539 bytes +Deserialize: per-group lrzip decompression... +Deserialize: decompression done in 17.7s +eval_compile:disabled_by_env +diagnostic quantized val_loss:2.44155528 val_bpb:1.11562304 eval_time:9926ms +beginning PPM sliding eval +ppm_mixer val_bpb:0.97206308 eval_time:453715ms order=5 H=0.999 L=0.18 T=0.8 N_tokens=47851520 N_sidecar_bytes=151074499 +ppm_sliding val_loss:2.45044876 val_bpb:0.97206308 eval_time:499038ms diff --git a/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/train_seed42.log b/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/train_seed42.log new file mode 100644 index 0000000000..6d399aae21 --- /dev/null +++ b/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/train_seed42.log @@ -0,0 +1,244 @@ +W0501 02:24:06.229000 907407 torch/distributed/run.py:803] +W0501 02:24:06.229000 907407 torch/distributed/run.py:803] ***************************************** +W0501 02:24:06.229000 907407 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0501 02:24:06.229000 907407 torch/distributed/run.py:803] ***************************************** +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + agree_add_boost: 0.5 + artifact_dir: /workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s42_20260501_022405 + attn_clip_sigmas: 13.0 + attn_out_gate_enabled: False + attn_out_gate_src: proj + awq_lite_bits: 8 + awq_lite_enabled: True + awq_lite_group_size: 64 + awq_lite_group_top_k: 1 + beta1: 0.9 + beta2: 0.99 + bijepa_backward_weight: 1.0 + bijepa_enabled: False + bijepax_bwd_hops: (4,) + bijepax_cycle: False + bijepax_enabled: True + bijepax_end_frac: 0.8 + bijepax_fwd_hops: (4,) + bijepax_head_dim: 32 + bijepax_lr: 0.001 + bijepax_start_frac: 0.35 + bijepax_stride: 64 + bijepax_weight: 0.01 + caseops_enabled: True + compressor: pergroup + data_dir: ./data/ + datasets_dir: /workspace/parameter-golf/data/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved + distributed: True + ema_decay: 0.9965 + embed_bits: 7 + embed_clip_sigmas: 14.0 + embed_lr: 0.6 + embed_wd: 0.085 + enable_looping_at: 0.35 + eval_seq_len: 2048 + eval_stride: 512 + fused_ce_enabled: True + gate_window: 12 + gated_attn_enabled: False + gated_attn_init_std: 0.01 + gated_attn_quant_gate: True + global_ttt_batch_seqs: 32 + global_ttt_chunk_tokens: 32768 + global_ttt_epochs: 1 + global_ttt_grad_clip: 1.0 + global_ttt_lr: 0.001 + global_ttt_momentum: 0.9 + global_ttt_respect_doc_boundaries: True + global_ttt_warmup_chunks: 0 + global_ttt_warmup_start_lr: 0.0 + gptq_calibration_batches: 16 + gptq_reserve_seconds: 0.5 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + is_main_process: True + iterations: 20000 + jepa_lite_enabled: False + jepa_lite_end_frac: 0.85 + jepa_lite_frequency: 8 + jepa_lite_hidden: 128 + jepa_lite_horizon: 4 + jepa_lite_lr: 0.01 + jepa_lite_start_frac: 0.35 + jepa_lite_stride: 16 + jepa_lite_weight: 0.03 + ln_scale: True + local_rank: 0 + logfile: /workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s42_20260501_022405/v15_bijepaxlite_lqer1_nocompile_s42_20260501_022405.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + lqer_asym_enabled: True + lqer_asym_group: 64 + lqer_enabled: True + lqer_factor_bits: 4 + lqer_gain_select: False + lqer_rank: 4 + lqer_scope: all + lqer_top_k: 1 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.026 + max_wallclock_seconds: 600.0 + min_lr: 0.1 + mlp_clip_sigmas: 11.5 + mlp_mult: 4.0 + model_dim: 512 + model_path: /workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s42_20260501_022405/final_model.pt + muon_backend_steps: 5 + muon_momentum: 0.97 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + ngram_hint_precompute_outside: True + ngram_tilt_enabled: True + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 2 + parallel_final_lane: mean + parallel_start_layer: 8 + phased_ttt_num_phases: 3 + phased_ttt_prefix_docs: 2500 + ppm_dump_inputs: False + ppm_h: 0.999 + ppm_l: 0.18 + ppm_mixer_enabled: True + ppm_order: 5 + ppm_t: 0.8 + qk_gain_init: 5.25 + quantized_model_path: /workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s42_20260501_022405/final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + rope_yarn: False + run_id: v15_bijepaxlite_lqer1_nocompile_s42_20260501_022405 + scalar_lr: 0.02 + seed: 42 + skip_gates_enabled: True + smear_gate_enabled: True + sparse_attn_gate_enabled: True + sparse_attn_gate_init_std: 0.0 + sparse_attn_gate_scale: 0.5 + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + token_boost: 2.625 + token_order: 16 + token_threshold: 0.8 + tokenizer_path: /workspace/parameter-golf/data/tokenizers/fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model + train_batch_tokens: 786432 + train_files: /workspace/parameter-golf/data/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_train_*.bin + train_log_every: 250 + train_seq_len: 2048 + ttt_batch_size: 64 + ttt_beta1: 0.0 + ttt_beta2: 0.99 + ttt_chunk_size: 48 + ttt_enabled: False + ttt_eval_batches: + ttt_eval_seq_len: 2048 + ttt_grad_steps: 1 + ttt_k_lora: True + ttt_local_lr_mult: 0.75 + ttt_lora_lr: 0.0001 + ttt_lora_rank: 80 + ttt_mask: no_qv + ttt_mlp_lora: True + ttt_o_lora: True + ttt_optimizer: adam + ttt_q_lora: False + ttt_train_window_tokens: 0 + ttt_v_lora: False + ttt_weight_decay: 0.5 + val_batch_tokens: 524288 + val_bytes_files: /workspace/parameter-golf/data/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_val_bytes_*.bin + val_doc_fraction: 1.0 + val_files: /workspace/parameter-golf/data/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.85 + warmup_steps: 20 + within_boost: 0.75 + within_tau: 0.45 + word_boost: 0.75 + word_normalize: strip_punct_lower + word_order: 4 + word_tau: 0.65 + world_size: 8 + xsa_last_n: 11 +train_shards: 80 +val_tokens: 47851520 +compile:disabled_by_env +model_params:35945673 +bijepax_config: weight=0.01 start=0.35 end=0.8 fwd=(4,) bwd=(4,) cycle=False head_dim=32 stride=64 lr=0.001 training_only=true +gptq:reserving 0s, effective=599500ms +warmup_cu_buckets:64,128,192,256 iters_each:3 +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup:enabled encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 +0/20000 val_loss: 9.0076 val_bpb: 4.1159 +1/20000 train_loss: 9.0087 train_time: 0.0m tok/s: 4294285 +2/20000 train_loss: 12.8207 train_time: 0.0m tok/s: 3779390 +3/20000 train_loss: 10.2260 train_time: 0.0m tok/s: 3626892 +4/20000 train_loss: 8.6824 train_time: 0.0m tok/s: 3557519 +5/20000 train_loss: 7.9442 train_time: 0.0m tok/s: 3518508 +250/20000 train_loss: 2.9190 train_time: 1.0m tok/s: 3366377 +500/20000 train_loss: 2.5569 train_time: 1.9m tok/s: 3370798 +750/20000 train_loss: 2.7019 train_time: 2.9m tok/s: 3372986 +layer_loop:enabled step:900 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +1000/20000 train_loss: 2.7481 train_time: 4.1m tok/s: 3211090 +1250/20000 train_loss: 2.5484 train_time: 5.5m tok/s: 2956874 +1500/20000 train_loss: 2.5173 train_time: 7.0m tok/s: 2808858 +1750/20000 train_loss: 2.5339 train_time: 8.5m tok/s: 2712218 +2000/20000 train_loss: 2.4834 train_time: 9.9m tok/s: 2643994 +2014/20000 val_loss: 2.4304 val_bpb: 1.1105 +stopping_early: wallclock_cap train_time: 599843ms step: 2014/20000 +peak memory allocated: 46676 MiB reserved: 52266 MiB +ema:applying EMA weights +diagnostic pre-quantization post-ema val_loss:2.42847180 val_bpb:1.10964479 eval_time:9908ms +Serialized model: 135418111 bytes +Code size (uncompressed): 213292 bytes +Code size (compressed): 41999 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 67 Hessians in 4.0s +Quantized weights: + gate_int8_row: blocks.attn.attn_gate_w + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int7)+awqgrpint8+lqer_asym: tok_emb.weight + passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, parallel_post_lambdas, parallel_resid_lambdas, skip_gates, skip_weights, smear_gate.weight, smear_lambda, softcap_neg, softcap_pos +Serialize: per-group lrzip compression... +Serialize: per-group compression done in 99.9s +Serialized model quantized+pergroup: 15955181 bytes +Total submission size quantized+pergroup: 15997180 bytes +Deserialize: per-group lrzip decompression... +Deserialize: decompression done in 17.4s +eval_compile:disabled_by_env +diagnostic quantized val_loss:2.44116551 val_bpb:1.11544494 eval_time:10342ms +beginning PPM sliding eval +ppm_mixer val_bpb:0.97234287 eval_time:456845ms order=5 H=0.999 L=0.18 T=0.8 N_tokens=47851520 N_sidecar_bytes=151074499 +ppm_sliding val_loss:2.45118426 val_bpb:0.97234287 eval_time:502131ms diff --git a/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/train_seed999.log b/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/train_seed999.log new file mode 100644 index 0000000000..771bcfde23 --- /dev/null +++ b/records/track_10min_16mb/2026-05-01_BIJEPAXLite_JEPA_PPM_0.97271/train_seed999.log @@ -0,0 +1,244 @@ +W0501 03:14:08.488000 1050941 torch/distributed/run.py:803] +W0501 03:14:08.488000 1050941 torch/distributed/run.py:803] ***************************************** +W0501 03:14:08.488000 1050941 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0501 03:14:08.488000 1050941 torch/distributed/run.py:803] ***************************************** +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + agree_add_boost: 0.5 + artifact_dir: /workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s999_20260501_031407 + attn_clip_sigmas: 13.0 + attn_out_gate_enabled: False + attn_out_gate_src: proj + awq_lite_bits: 8 + awq_lite_enabled: True + awq_lite_group_size: 64 + awq_lite_group_top_k: 1 + beta1: 0.9 + beta2: 0.99 + bijepa_backward_weight: 1.0 + bijepa_enabled: False + bijepax_bwd_hops: (4,) + bijepax_cycle: False + bijepax_enabled: True + bijepax_end_frac: 0.8 + bijepax_fwd_hops: (4,) + bijepax_head_dim: 32 + bijepax_lr: 0.001 + bijepax_start_frac: 0.35 + bijepax_stride: 64 + bijepax_weight: 0.01 + caseops_enabled: True + compressor: pergroup + data_dir: ./data/ + datasets_dir: /workspace/parameter-golf/data/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved + distributed: True + ema_decay: 0.9965 + embed_bits: 7 + embed_clip_sigmas: 14.0 + embed_lr: 0.6 + embed_wd: 0.085 + enable_looping_at: 0.35 + eval_seq_len: 2048 + eval_stride: 512 + fused_ce_enabled: True + gate_window: 12 + gated_attn_enabled: False + gated_attn_init_std: 0.01 + gated_attn_quant_gate: True + global_ttt_batch_seqs: 32 + global_ttt_chunk_tokens: 32768 + global_ttt_epochs: 1 + global_ttt_grad_clip: 1.0 + global_ttt_lr: 0.001 + global_ttt_momentum: 0.9 + global_ttt_respect_doc_boundaries: True + global_ttt_warmup_chunks: 0 + global_ttt_warmup_start_lr: 0.0 + gptq_calibration_batches: 16 + gptq_reserve_seconds: 0.5 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + is_main_process: True + iterations: 20000 + jepa_lite_enabled: False + jepa_lite_end_frac: 0.85 + jepa_lite_frequency: 8 + jepa_lite_hidden: 128 + jepa_lite_horizon: 4 + jepa_lite_lr: 0.01 + jepa_lite_start_frac: 0.35 + jepa_lite_stride: 16 + jepa_lite_weight: 0.03 + ln_scale: True + local_rank: 0 + logfile: /workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s999_20260501_031407/v15_bijepaxlite_lqer1_nocompile_s999_20260501_031407.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + lqer_asym_enabled: True + lqer_asym_group: 64 + lqer_enabled: True + lqer_factor_bits: 4 + lqer_gain_select: False + lqer_rank: 4 + lqer_scope: all + lqer_top_k: 1 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.026 + max_wallclock_seconds: 600.0 + min_lr: 0.1 + mlp_clip_sigmas: 11.5 + mlp_mult: 4.0 + model_dim: 512 + model_path: /workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s999_20260501_031407/final_model.pt + muon_backend_steps: 5 + muon_momentum: 0.97 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + ngram_hint_precompute_outside: True + ngram_tilt_enabled: True + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 2 + parallel_final_lane: mean + parallel_start_layer: 8 + phased_ttt_num_phases: 3 + phased_ttt_prefix_docs: 2500 + ppm_dump_inputs: False + ppm_h: 0.999 + ppm_l: 0.18 + ppm_mixer_enabled: True + ppm_order: 5 + ppm_t: 0.8 + qk_gain_init: 5.25 + quantized_model_path: /workspace/parameter-golf/our_submission/1000/runs/v15_bijepaxlite_lqer1_nocompile_s999_20260501_031407/final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + rope_yarn: False + run_id: v15_bijepaxlite_lqer1_nocompile_s999_20260501_031407 + scalar_lr: 0.02 + seed: 999 + skip_gates_enabled: True + smear_gate_enabled: True + sparse_attn_gate_enabled: True + sparse_attn_gate_init_std: 0.0 + sparse_attn_gate_scale: 0.5 + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + token_boost: 2.625 + token_order: 16 + token_threshold: 0.8 + tokenizer_path: /workspace/parameter-golf/data/tokenizers/fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model + train_batch_tokens: 786432 + train_files: /workspace/parameter-golf/data/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_train_*.bin + train_log_every: 250 + train_seq_len: 2048 + ttt_batch_size: 64 + ttt_beta1: 0.0 + ttt_beta2: 0.99 + ttt_chunk_size: 48 + ttt_enabled: False + ttt_eval_batches: + ttt_eval_seq_len: 2048 + ttt_grad_steps: 1 + ttt_k_lora: True + ttt_local_lr_mult: 0.75 + ttt_lora_lr: 0.0001 + ttt_lora_rank: 80 + ttt_mask: no_qv + ttt_mlp_lora: True + ttt_o_lora: True + ttt_optimizer: adam + ttt_q_lora: False + ttt_train_window_tokens: 0 + ttt_v_lora: False + ttt_weight_decay: 0.5 + val_batch_tokens: 524288 + val_bytes_files: /workspace/parameter-golf/data/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_val_bytes_*.bin + val_doc_fraction: 1.0 + val_files: /workspace/parameter-golf/data/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.85 + warmup_steps: 20 + within_boost: 0.75 + within_tau: 0.45 + word_boost: 0.75 + word_normalize: strip_punct_lower + word_order: 4 + word_tau: 0.65 + world_size: 8 + xsa_last_n: 11 +train_shards: 80 +val_tokens: 47851520 +compile:disabled_by_env +model_params:35945673 +bijepax_config: weight=0.01 start=0.35 end=0.8 fwd=(4,) bwd=(4,) cycle=False head_dim=32 stride=64 lr=0.001 training_only=true +gptq:reserving 0s, effective=599500ms +warmup_cu_buckets:64,128,192,256 iters_each:3 +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup:enabled encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 +0/20000 val_loss: 9.0149 val_bpb: 4.1192 +1/20000 train_loss: 9.0152 train_time: 0.0m tok/s: 4307036 +2/20000 train_loss: 12.8530 train_time: 0.0m tok/s: 3767064 +3/20000 train_loss: 10.1713 train_time: 0.0m tok/s: 3626995 +4/20000 train_loss: 8.6216 train_time: 0.0m tok/s: 3559304 +5/20000 train_loss: 7.9225 train_time: 0.0m tok/s: 3520003 +250/20000 train_loss: 2.9185 train_time: 1.0m tok/s: 3365671 +500/20000 train_loss: 2.5589 train_time: 1.9m tok/s: 3370654 +750/20000 train_loss: 2.7064 train_time: 2.9m tok/s: 3370035 +layer_loop:enabled step:899 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +1000/20000 train_loss: 2.7514 train_time: 4.1m tok/s: 3206994 +1250/20000 train_loss: 2.5502 train_time: 5.5m tok/s: 2953990 +1500/20000 train_loss: 2.5201 train_time: 7.0m tok/s: 2806608 +1750/20000 train_loss: 2.5365 train_time: 8.5m tok/s: 2710429 +2000/20000 train_loss: 2.4844 train_time: 9.9m tok/s: 2642524 +2013/20000 val_loss: 2.4352 val_bpb: 1.1127 +stopping_early: wallclock_cap train_time: 599821ms step: 2013/20000 +peak memory allocated: 46676 MiB reserved: 52266 MiB +ema:applying EMA weights +diagnostic pre-quantization post-ema val_loss:2.43314506 val_bpb:1.11178015 eval_time:9911ms +Serialized model: 135418111 bytes +Code size (uncompressed): 213292 bytes +Code size (compressed): 41999 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 67 Hessians in 4.0s +Quantized weights: + gate_int8_row: blocks.attn.attn_gate_w + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int7)+awqgrpint8+lqer_asym: tok_emb.weight + passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, parallel_post_lambdas, parallel_resid_lambdas, skip_gates, skip_weights, smear_gate.weight, smear_lambda, softcap_neg, softcap_pos +Serialize: per-group lrzip compression... +Serialize: per-group compression done in 100.8s +Serialized model quantized+pergroup: 15955594 bytes +Total submission size quantized+pergroup: 15997593 bytes +Deserialize: per-group lrzip decompression... +Deserialize: decompression done in 17.3s +eval_compile:disabled_by_env +diagnostic quantized val_loss:2.44582432 val_bpb:1.11757370 eval_time:11393ms +beginning PPM sliding eval +ppm_mixer val_bpb:0.97373767 eval_time:451054ms order=5 H=0.999 L=0.18 T=0.8 N_tokens=47851520 N_sidecar_bytes=151074499 +ppm_sliding val_loss:2.45502055 val_bpb:0.97373767 eval_time:496384ms