diff --git a/records/track_non_record_16mb/2026-04-12_Fractal_RecurrentPrimitive_SP1024_1xH100/README.md b/records/track_non_record_16mb/2026-04-12_Fractal_RecurrentPrimitive_SP1024_1xH100/README.md new file mode 100644 index 0000000000..86196e06bf --- /dev/null +++ b/records/track_non_record_16mb/2026-04-12_Fractal_RecurrentPrimitive_SP1024_1xH100/README.md @@ -0,0 +1,108 @@ +# Fractal Recurrent Primitive Hybrid: Non-Record Research Submission + +**Track:** `track_non_record_16mb` +**Author:** Joseph Abraham ([@abbudjoe](https://github.com/abbudjoe)) +**Hardware:** 1xH100 80GB +**Status:** Non-record, research contribution +**Best recurrent-primitive artifact:** 1.357619 BPB, loss 2.292283, 14,440,584 bytes +**Source training run:** 10-minute SP1024 80-shard checkpoint, 988 steps, 600.491s +**Seed coverage:** single matching seed (`seed=42`), not a 3-seed mean + +## Short Version + +This submission documents a controlled attempt to bring a custom Fractal recurrent primitive into the Parameter Golf stack. The primitive was ported as a single middle recurrent layer inside an otherwise transformer-derived 11L/512 SP1024 model, using schedule `AAAAAPAAAAA`. + +The result is not a leaderboard record. The best recurrent-primitive hybrid trails the pure-attention control under the same checkpoint requantization sweep. The useful finding is narrower: the recurrent primitive can stay close under the 16MB cap when protected with all-large-int8 quantization, but naive attention replacement is not yet the right insertion contract. + +## Attribution and Leaderboard Provenance + +This experiment was guided by the public leaderboard meta, especially the current #1 record. The recurrent primitive is the new variable here; the surrounding stack borrows several proven ingredients from prior public work. + +| Used ingredient | Credit | How it appears here | +|---|---|---| +| Mixed int6/int8 quantization pressure and protected higher-precision export variants | PR #1394 @clarkkev | The source runs use mixed int6 clipsearch + zstd, and the best 10-minute export is an all-large-int8/zstd protection sweep. | +| Learnable per-head QK gain machinery | #1 record stack | The transformer attention path includes learnable per-head query scaling. | +| EMA 0.9965 and warmdown-style schedules | PR #1445 @X-Abhishek-X | Both the 10-minute and 60-minute recurrent runs use EMA decay 0.9965; the 60-minute probe also uses a longer warmdown schedule. | + +## What Was Tested + +The intended ablation was one variable at a time: + +- Keep tokenizer, SP1024 data path, optimizer, training budget, evaluation path, and quantization machinery fixed. +- Replace one middle attention/MLP transformer block with the Fractal recurrent primitive. +- Use the Triton runtime and block-structured recurrent state path from the Fractal prototype. +- Compare the trained recurrent-primitive checkpoint against a pure-attention control through the same quantization sweep. + +This PR includes the training script snapshot, recurrent runtime files, checkpoint requantization helper, and run summaries/logs for both the recurrent-primitive hybrid and the pure-attention control. + +## Results + +| Experiment | Model | Quant/export | Pre BPB | Post BPB | Post loss | Bytes | Notes | +|---|---:|---|---:|---:|---:|---:|---| +| Pure attention control | 11L/512 SP1024 | mixed int6 default | 1.343710 | 1.359737 | 2.295859 | 10,294,744 | Same quant sweep control | +| Pure attention control | 11L/512 SP1024 | all-large-int8 | 1.343710 | 1.344724 | 2.270510 | 14,966,424 | Best control export | +| Fractal recurrent hybrid | 11L/512 SP1024, `AAAAAPAAAAA` | mixed int6 default | 1.356221 | 1.376010 | 2.323335 | 10,044,945 | Single middle recurrent slot | +| Fractal recurrent hybrid | 11L/512 SP1024, `AAAAAPAAAAA` | all-large-int8 | 1.356221 | 1.357619 | 2.292283 | 14,440,584 | Best recurrent export | + +These are single-seed matched comparisons (`seed=42`). We do not yet have a 3-seed confirmation for this exact headline configuration. + +The original 10-minute recurrent-primitive source run recorded post-quant BPB 1.376889 at 9,747,772 bytes under its initial export path. The requantization sweep in this folder is the cleaner like-for-like comparison because both the recurrent hybrid and pure attention were re-exported through the same variants. + +## Extended 60-Minute Probe + +I also ran a longer, non-competition-timed probe to see whether the recurrent primitive was still improving with more training time. This used schedule `AAAPAAAAPAA`, 60 minutes of training time after compile/prewarm, warmdown 4000, EMA 0.9965, mixed int6 clipsearch + zstd, and 1xH100. + +| Run | Steps | Train time | Pre BPB | Post BPB | Post loss | Bytes | Notes | +|---|---:|---:|---:|---:|---:|---:|---| +| Fractal recurrent hybrid 60-minute probe | 5,342 | 3,600.034s | 1.230186 | 1.241819 | 2.701189 | 16,179,345 | Out of the 10-minute rule and slightly over 16MB | + +This result is useful for research direction, not leaderboard scoring. It shows the recurrent primitive continued improving substantially with enough wall-clock time, but the current implementation is too slow and too large to translate that longer-run quality into the official 10-minute/16MB track without more systems work or multi-GPU scaling. + +## Interpretation + +The recurrent-primitive hybrid did not beat the pure-attention control. At high precision under the 16MB cap, it is about 0.0129 BPB behind the corresponding pure-attention export. The replacement also saved bytes in some variants, but the saved bytes were not enough to compensate for lower pre-quant quality and slower recurrent compute. + +The positive result is that the quantization failure mode is tractable. For the recurrent primitive, the post-minus-pre gap improved from about +0.0198 BPB under the default mixed export to +0.0014 BPB under all-large-int8 while remaining below 16MB. That suggests future work should focus on better insertion contracts and context/state leverage, not on raw one-for-one attention replacement. + +## Why This Is Useful + +This is a negative result with a reusable baseline: + +- The Fractal recurrent primitive is better treated as a side-channel, looped adapter, or context-state module than as a direct attention replacement. +- The recurrent/state matrices need quantization protection. +- Pure attention remains the stronger 10-minute small-context baseline on this surface. +- The next fair test is not "replace attention"; it is "keep the transformer stack and add recurrent state where it can buy context, TTT, or memory efficiency." + +## Included Files + +- `train_gpt.py`: training/evaluation script snapshot used for the recurrent experiments. +- Recurrent primitive runtime snapshot. +- Recurrent-primitive TTT adapter support snapshot. +- `supporting_files/requant_checkpoint_sweep.py`: helper used to re-export a trained checkpoint across quantization variants. +- `logs/source_summary_seed42.json`: source 10-minute recurrent checkpoint summary. +- `logs/quant_sweep_seed42.json`: recurrent checkpoint quantization sweep summary. +- `logs/quant_sweep_seed42.log`: recurrent checkpoint quantization sweep console log. +- `logs/baseline_quant_control_seed42.json`: pure-attention quantization control summary. +- `logs/baseline_quant_control_seed42.log`: pure-attention quantization control console log. +- `logs/extended_60min_summary_seed42.json`: longer recurrent-primitive trajectory probe, outside official time/size constraints. + +## Reproduction Notes + +The headline training run used the following intended contract: + +```bash +MODEL_FAMILY= \ +LAYER_SCHEDULE=AAAAAPAAAAA \ +RECURRENT_RUNTIME_BACKEND=triton \ +RECURRENT_STATE_BLOCKS=auto \ +RECURRENT_TTT_MODE=off \ +SPM_VOCAB_SIZE=1024 \ +TRAIN_TIME_LIMIT_SECONDS=600 \ +python train_gpt.py +``` + +The best reported recurrent-primitive export came from reusing that trained checkpoint and running the included quantization sweep, with `all_large_int8` selected as the best under-16MB export. + +## Non-Record Disclaimer + +This is intentionally submitted as a non-record PR. It is a controlled research note showing that a Fractal recurrent primitive can approach, but not beat, the pure-attention control under this particular 1xH100 10-minute contract. The contribution is the ablation, the code snapshot, and the evidence that future recurrent work should move toward side-channel/context-state insertion rather than direct block replacement. diff --git a/records/track_non_record_16mb/2026-04-12_Fractal_RecurrentPrimitive_SP1024_1xH100/logs/baseline_quant_control_seed42.json b/records/track_non_record_16mb/2026-04-12_Fractal_RecurrentPrimitive_SP1024_1xH100/logs/baseline_quant_control_seed42.json new file mode 100644 index 0000000000..3e312fd067 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-12_Fractal_RecurrentPrimitive_SP1024_1xH100/logs/baseline_quant_control_seed42.json @@ -0,0 +1,309 @@ +{ + "best_post_quant_val_bpb": 1.3447236430030274, + "best_post_quant_val_loss": 2.270509543026506, + "best_total_submission_bytes_quantized": 14966424, + "best_variant": "all_large_int8", + "code_bytes": 231852, + "code_files": [ + { + "bytes": 160385, + "path": "train_gpt.py" + }, + { + "bytes": 41584, + "path": "p20_runtime.py" + }, + { + "bytes": 29883, + "path": "p20_ttt.py" + } + ], + "compiled_model": true, + "event": "checkpoint_requant_sweep_completed", + "experiment_kind": "step4_quantization_protection_sweep", + "mixed_int6_clip_quantiles": [ + 0.999, + 0.9995, + 0.9999, + 0.99999, + 1.0 + ], + "mlp_hidden_dim": 1024, + "mlp_mult": 2, + "model_dim": 512, + "model_family": "baseline", + "num_heads": 8, + "num_kv_heads": 4, + "num_layers": 11, + "p20_block_pair_width_cap": 32, + "p20_layer_schedule": "", + "p20_runtime_backend": "torch", + "p20_state_blocks": "2", + "pre_quant_eval_time_ms": 52757.67647102475, + "pre_quant_val_bpb": 1.343710204513414, + "pre_quant_val_loss": 2.2687983945880066, + "quant_compression": "zstd", + "raw_model_path": ".codex/autoresearch/results_cuda_canonical_sp1024_check/baseline_sp1024_80shards_10min_gate_seed42_20260412T041056Z/baseline_sp1024_80shards_10min_gate_seed42_20260412T041056Z_model.pt", + "run_id": "baseline_quant_sweep_seed42_20260412T095500Z", + "source_run_id": "baseline_sp1024_80shards_10min_gate_seed42_20260412T041056Z", + "source_summary_path": ".codex/autoresearch/results_cuda_canonical_sp1024_check/baseline_sp1024_80shards_10min_gate_seed42_20260412T041056Z/summary.json", + "variants": [ + { + "fp32_patterns": [ + "attn_scale", + "attn_scales", + "mlp_scale", + "mlp_scales", + "resid_mix", + "resid_mixes", + "q_gain", + "skip_weight", + "skip_weights", + "bigram.scale" + ], + "mixed_int6_categories": [ + "attn", + "mlp" + ], + "name": "p20_int8_default", + "note": "Reproduce the Step 3 leader: attention/MLP int6, P20 large matrices int8, small tensors passthrough.", + "post_minus_pre_bpb": 0.016026586094396666, + "post_quant_eval_time_ms": 16081.944196950644, + "post_quant_val_bpb": 1.3597367906078106, + "post_quant_val_loss": 2.295858613881988, + "quant_file_bytes": 10062892, + "quant_path": ".codex/autoresearch/results_cuda_step4_quant_sweep/baseline_quant_sweep_seed42_20260412T095500Z/baseline_quant_sweep_seed42_20260412T095500Z_p20_int8_default.int6clip.zstd.ptz", + "quant_raw_bytes": 20936012, + "quant_stats": { + "baseline_tensor_bytes": 81889632, + "clip_quantiles": [ + 0.999, + 0.9995, + 0.9999, + 0.99999, + 1.0 + ], + "clipsearch_mse_sum": 0.005724633359932341, + "int6_categories": [ + "attn", + "mlp" + ], + "int6_tensor_count": 66, + "int8_tensor_count": 1, + "num_float_tensors": 112, + "num_nonfloat_tensors": 0, + "num_tensors": 112, + "param_count": 20734552, + "passthrough_tensor_count": 45, + "quantized_payload_bytes": 20879712 + }, + "quant_time_ms": 20740.368217928335, + "total_submission_bytes_quantized": 10294744 + }, + { + "fp32_patterns": [ + "attn_scale", + "attn_scales", + "mlp_scale", + "mlp_scales", + "resid_mix", + "resid_mixes", + "q_gain", + "skip_weight", + "skip_weights", + "bigram.scale", + "residual_scale", + "input_norm", + "output_norm", + "primitive.in_projection.projection.bias", + "primitive.state_transform_projection.bias", + "readout_projection.bias" + ], + "mixed_int6_categories": [ + "attn", + "mlp" + ], + "name": "p20_int8_p20_small_fp32", + "note": "Keep P20 residual/norm/bias tensors in fp32 instead of fp16 passthrough.", + "post_minus_pre_bpb": 0.016026586094396666, + "post_quant_eval_time_ms": 16104.23754202202, + "post_quant_val_bpb": 1.3597367906078106, + "post_quant_val_loss": 2.295858613881988, + "quant_file_bytes": 10062892, + "quant_path": ".codex/autoresearch/results_cuda_step4_quant_sweep/baseline_quant_sweep_seed42_20260412T095500Z/baseline_quant_sweep_seed42_20260412T095500Z_p20_int8_p20_small_fp32.int6clip.zstd.ptz", + "quant_raw_bytes": 20936012, + "quant_stats": { + "baseline_tensor_bytes": 81889632, + "clip_quantiles": [ + 0.999, + 0.9995, + 0.9999, + 0.99999, + 1.0 + ], + "clipsearch_mse_sum": 0.005724633359932341, + "int6_categories": [ + "attn", + "mlp" + ], + "int6_tensor_count": 66, + "int8_tensor_count": 1, + "num_float_tensors": 112, + "num_nonfloat_tensors": 0, + "num_tensors": 112, + "param_count": 20734552, + "passthrough_tensor_count": 45, + "quantized_payload_bytes": 20879712 + }, + "quant_time_ms": 18807.846677023917, + "total_submission_bytes_quantized": 10294744 + }, + { + "fp32_patterns": [ + "attn_scale", + "attn_scales", + "mlp_scale", + "mlp_scales", + "resid_mix", + "resid_mixes", + "q_gain", + "skip_weight", + "skip_weights", + "bigram.scale" + ], + "mixed_int6_categories": [ + "mlp" + ], + "name": "attn_int8_mlp_int6_p20_int8", + "note": "Spend bytes protecting attention and P20 at int8 while leaving MLP at int6.", + "post_minus_pre_bpb": 0.011263235059893795, + "post_quant_eval_time_ms": 16145.048178965226, + "post_quant_val_bpb": 1.3549734395733077, + "post_quant_val_loss": 2.2878158951888956, + "quant_file_bytes": 11816510, + "quant_path": ".codex/autoresearch/results_cuda_step4_quant_sweep/baseline_quant_sweep_seed42_20260412T095500Z/baseline_quant_sweep_seed42_20260412T095500Z_attn_int8_mlp_int6_p20_int8.int6clip.zstd.ptz", + "quant_raw_bytes": 20936012, + "quant_stats": { + "baseline_tensor_bytes": 81889632, + "clip_quantiles": [ + 0.999, + 0.9995, + 0.9999, + 0.99999, + 1.0 + ], + "clipsearch_mse_sum": 0.0019073950898018666, + "int6_categories": [ + "mlp" + ], + "int6_tensor_count": 22, + "int8_tensor_count": 45, + "num_float_tensors": 112, + "num_nonfloat_tensors": 0, + "num_tensors": 112, + "param_count": 20734552, + "passthrough_tensor_count": 45, + "quantized_payload_bytes": 20879712 + }, + "quant_time_ms": 12305.522667942569, + "total_submission_bytes_quantized": 12048362 + }, + { + "fp32_patterns": [ + "attn_scale", + "attn_scales", + "mlp_scale", + "mlp_scales", + "resid_mix", + "resid_mixes", + "q_gain", + "skip_weight", + "skip_weights", + "bigram.scale" + ], + "mixed_int6_categories": [ + "attn" + ], + "name": "attn_int6_mlp_int8_p20_int8", + "note": "Spend bytes protecting MLP and P20 at int8 while leaving attention at int6.", + "post_minus_pre_bpb": 0.005479778336315189, + "post_quant_eval_time_ms": 16160.725696012378, + "post_quant_val_bpb": 1.3491899828497291, + "post_quant_val_loss": 2.2780507707702893, + "quant_file_bytes": 12825617, + "quant_path": ".codex/autoresearch/results_cuda_step4_quant_sweep/baseline_quant_sweep_seed42_20260412T095500Z/baseline_quant_sweep_seed42_20260412T095500Z_attn_int6_mlp_int8_p20_int8.int6clip.zstd.ptz", + "quant_raw_bytes": 20936012, + "quant_stats": { + "baseline_tensor_bytes": 81889632, + "clip_quantiles": [ + 0.999, + 0.9995, + 0.9999, + 0.99999, + 1.0 + ], + "clipsearch_mse_sum": 0.003817238270130474, + "int6_categories": [ + "attn" + ], + "int6_tensor_count": 44, + "int8_tensor_count": 23, + "num_float_tensors": 112, + "num_nonfloat_tensors": 0, + "num_tensors": 112, + "param_count": 20734552, + "passthrough_tensor_count": 45, + "quantized_payload_bytes": 20879712 + }, + "quant_time_ms": 12128.029583021998, + "total_submission_bytes_quantized": 13057469 + }, + { + "fp32_patterns": [ + "attn_scale", + "attn_scales", + "mlp_scale", + "mlp_scales", + "resid_mix", + "resid_mixes", + "q_gain", + "skip_weight", + "skip_weights", + "bigram.scale" + ], + "mixed_int6_categories": [], + "name": "all_large_int8", + "note": "Use int8 for all large tensors; this is the upper-bound quantization-protection check.", + "post_minus_pre_bpb": 0.0010134384896134385, + "post_quant_eval_time_ms": 16176.252738106996, + "post_quant_val_bpb": 1.3447236430030274, + "post_quant_val_loss": 2.270509543026506, + "quant_file_bytes": 14734572, + "quant_path": ".codex/autoresearch/results_cuda_step4_quant_sweep/baseline_quant_sweep_seed42_20260412T095500Z/baseline_quant_sweep_seed42_20260412T095500Z_all_large_int8.int6clip.zstd.ptz", + "quant_raw_bytes": 20935948, + "quant_stats": { + "baseline_tensor_bytes": 81889632, + "clip_quantiles": [ + 0.999, + 0.9995, + 0.9999, + 0.99999, + 1.0 + ], + "clipsearch_mse_sum": 0.0, + "int6_categories": [], + "int6_tensor_count": 0, + "int8_tensor_count": 67, + "num_float_tensors": 112, + "num_nonfloat_tensors": 0, + "num_tensors": 112, + "param_count": 20734552, + "passthrough_tensor_count": 45, + "quantized_payload_bytes": 20879712 + }, + "quant_time_ms": 7733.910084003583, + "total_submission_bytes_quantized": 14966424 + } + ], + "vocab_size": 1024 +} diff --git a/records/track_non_record_16mb/2026-04-12_Fractal_RecurrentPrimitive_SP1024_1xH100/logs/baseline_quant_control_seed42.log b/records/track_non_record_16mb/2026-04-12_Fractal_RecurrentPrimitive_SP1024_1xH100/logs/baseline_quant_control_seed42.log new file mode 100644 index 0000000000..2708e12924 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-12_Fractal_RecurrentPrimitive_SP1024_1xH100/logs/baseline_quant_control_seed42.log @@ -0,0 +1,11 @@ +/workspace/parameter-golf-tokenizer/parameter-golf/.codex/autoresearch/scripts/requant_checkpoint_sweep.py:157: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + state = torch.load(raw_model_path, map_location="cpu") +/workspace/parameter-golf-tokenizer/parameter-golf/.codex/autoresearch/scripts/requant_checkpoint_sweep.py:201: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + roundtrip_obj = torch.load( +variant {"fp32_patterns": ["attn_scale", "attn_scales", "mlp_scale", "mlp_scales", "resid_mix", "resid_mixes", "q_gain", "skip_weight", "skip_weights", "bigram.scale"], "mixed_int6_categories": ["attn", "mlp"], "name": "p20_int8_default", "note": "Reproduce the Step 3 leader: attention/MLP int6, P20 large matrices int8, small tensors passthrough.", "post_minus_pre_bpb": 0.016026586094396666, "post_quant_eval_time_ms": 16081.944196950644, "post_quant_val_bpb": 1.3597367906078106, "post_quant_val_loss": 2.295858613881988, "quant_file_bytes": 10062892, "quant_path": ".codex/autoresearch/results_cuda_step4_quant_sweep/baseline_quant_sweep_seed42_20260412T095500Z/baseline_quant_sweep_seed42_20260412T095500Z_p20_int8_default.int6clip.zstd.ptz", "quant_raw_bytes": 20936012, "quant_stats": {"baseline_tensor_bytes": 81889632, "clip_quantiles": [0.999, 0.9995, 0.9999, 0.99999, 1.0], "clipsearch_mse_sum": 0.005724633359932341, "int6_categories": ["attn", "mlp"], "int6_tensor_count": 66, "int8_tensor_count": 1, "num_float_tensors": 112, "num_nonfloat_tensors": 0, "num_tensors": 112, "param_count": 20734552, "passthrough_tensor_count": 45, "quantized_payload_bytes": 20879712}, "quant_time_ms": 20740.368217928335, "total_submission_bytes_quantized": 10294744} +variant {"fp32_patterns": ["attn_scale", "attn_scales", "mlp_scale", "mlp_scales", "resid_mix", "resid_mixes", "q_gain", "skip_weight", "skip_weights", "bigram.scale", "residual_scale", "input_norm", "output_norm", "primitive.in_projection.projection.bias", "primitive.state_transform_projection.bias", "readout_projection.bias"], "mixed_int6_categories": ["attn", "mlp"], "name": "p20_int8_p20_small_fp32", "note": "Keep P20 residual/norm/bias tensors in fp32 instead of fp16 passthrough.", "post_minus_pre_bpb": 0.016026586094396666, "post_quant_eval_time_ms": 16104.23754202202, "post_quant_val_bpb": 1.3597367906078106, "post_quant_val_loss": 2.295858613881988, "quant_file_bytes": 10062892, "quant_path": ".codex/autoresearch/results_cuda_step4_quant_sweep/baseline_quant_sweep_seed42_20260412T095500Z/baseline_quant_sweep_seed42_20260412T095500Z_p20_int8_p20_small_fp32.int6clip.zstd.ptz", "quant_raw_bytes": 20936012, "quant_stats": {"baseline_tensor_bytes": 81889632, "clip_quantiles": [0.999, 0.9995, 0.9999, 0.99999, 1.0], "clipsearch_mse_sum": 0.005724633359932341, "int6_categories": ["attn", "mlp"], "int6_tensor_count": 66, "int8_tensor_count": 1, "num_float_tensors": 112, "num_nonfloat_tensors": 0, "num_tensors": 112, "param_count": 20734552, "passthrough_tensor_count": 45, "quantized_payload_bytes": 20879712}, "quant_time_ms": 18807.846677023917, "total_submission_bytes_quantized": 10294744} +variant {"fp32_patterns": ["attn_scale", "attn_scales", "mlp_scale", "mlp_scales", "resid_mix", "resid_mixes", "q_gain", "skip_weight", "skip_weights", "bigram.scale"], "mixed_int6_categories": ["mlp"], "name": "attn_int8_mlp_int6_p20_int8", "note": "Spend bytes protecting attention and P20 at int8 while leaving MLP at int6.", "post_minus_pre_bpb": 0.011263235059893795, "post_quant_eval_time_ms": 16145.048178965226, "post_quant_val_bpb": 1.3549734395733077, "post_quant_val_loss": 2.2878158951888956, "quant_file_bytes": 11816510, "quant_path": ".codex/autoresearch/results_cuda_step4_quant_sweep/baseline_quant_sweep_seed42_20260412T095500Z/baseline_quant_sweep_seed42_20260412T095500Z_attn_int8_mlp_int6_p20_int8.int6clip.zstd.ptz", "quant_raw_bytes": 20936012, "quant_stats": {"baseline_tensor_bytes": 81889632, "clip_quantiles": [0.999, 0.9995, 0.9999, 0.99999, 1.0], "clipsearch_mse_sum": 0.0019073950898018666, "int6_categories": ["mlp"], "int6_tensor_count": 22, "int8_tensor_count": 45, "num_float_tensors": 112, "num_nonfloat_tensors": 0, "num_tensors": 112, "param_count": 20734552, "passthrough_tensor_count": 45, "quantized_payload_bytes": 20879712}, "quant_time_ms": 12305.522667942569, "total_submission_bytes_quantized": 12048362} +variant {"fp32_patterns": ["attn_scale", "attn_scales", "mlp_scale", "mlp_scales", "resid_mix", "resid_mixes", "q_gain", "skip_weight", "skip_weights", "bigram.scale"], "mixed_int6_categories": ["attn"], "name": "attn_int6_mlp_int8_p20_int8", "note": "Spend bytes protecting MLP and P20 at int8 while leaving attention at int6.", "post_minus_pre_bpb": 0.005479778336315189, "post_quant_eval_time_ms": 16160.725696012378, "post_quant_val_bpb": 1.3491899828497291, "post_quant_val_loss": 2.2780507707702893, "quant_file_bytes": 12825617, "quant_path": ".codex/autoresearch/results_cuda_step4_quant_sweep/baseline_quant_sweep_seed42_20260412T095500Z/baseline_quant_sweep_seed42_20260412T095500Z_attn_int6_mlp_int8_p20_int8.int6clip.zstd.ptz", "quant_raw_bytes": 20936012, "quant_stats": {"baseline_tensor_bytes": 81889632, "clip_quantiles": [0.999, 0.9995, 0.9999, 0.99999, 1.0], "clipsearch_mse_sum": 0.003817238270130474, "int6_categories": ["attn"], "int6_tensor_count": 44, "int8_tensor_count": 23, "num_float_tensors": 112, "num_nonfloat_tensors": 0, "num_tensors": 112, "param_count": 20734552, "passthrough_tensor_count": 45, "quantized_payload_bytes": 20879712}, "quant_time_ms": 12128.029583021998, "total_submission_bytes_quantized": 13057469} +variant {"fp32_patterns": ["attn_scale", "attn_scales", "mlp_scale", "mlp_scales", "resid_mix", "resid_mixes", "q_gain", "skip_weight", "skip_weights", "bigram.scale"], "mixed_int6_categories": [], "name": "all_large_int8", "note": "Use int8 for all large tensors; this is the upper-bound quantization-protection check.", "post_minus_pre_bpb": 0.0010134384896134385, "post_quant_eval_time_ms": 16176.252738106996, "post_quant_val_bpb": 1.3447236430030274, "post_quant_val_loss": 2.270509543026506, "quant_file_bytes": 14734572, "quant_path": ".codex/autoresearch/results_cuda_step4_quant_sweep/baseline_quant_sweep_seed42_20260412T095500Z/baseline_quant_sweep_seed42_20260412T095500Z_all_large_int8.int6clip.zstd.ptz", "quant_raw_bytes": 20935948, "quant_stats": {"baseline_tensor_bytes": 81889632, "clip_quantiles": [0.999, 0.9995, 0.9999, 0.99999, 1.0], "clipsearch_mse_sum": 0.0, "int6_categories": [], "int6_tensor_count": 0, "int8_tensor_count": 67, "num_float_tensors": 112, "num_nonfloat_tensors": 0, "num_tensors": 112, "param_count": 20734552, "passthrough_tensor_count": 45, "quantized_payload_bytes": 20879712}, "quant_time_ms": 7733.910084003583, "total_submission_bytes_quantized": 14966424} +summary_path .codex/autoresearch/results_cuda_step4_quant_sweep/baseline_quant_sweep_seed42_20260412T095500Z/summary.json +{"best_post_quant_val_bpb": 1.3447236430030274, "best_post_quant_val_loss": 2.270509543026506, "best_total_submission_bytes_quantized": 14966424, "best_variant": "all_large_int8", "code_bytes": 231852, "code_files": [{"bytes": 160385, "path": "train_gpt.py"}, {"bytes": 41584, "path": "p20_runtime.py"}, {"bytes": 29883, "path": "p20_ttt.py"}], "compiled_model": true, "event": "checkpoint_requant_sweep_completed", "experiment_kind": "step4_quantization_protection_sweep", "mixed_int6_clip_quantiles": [0.999, 0.9995, 0.9999, 0.99999, 1.0], "mlp_hidden_dim": 1024, "mlp_mult": 2, "model_dim": 512, "model_family": "baseline", "num_heads": 8, "num_kv_heads": 4, "num_layers": 11, "p20_block_pair_width_cap": 32, "p20_layer_schedule": "", "p20_runtime_backend": "torch", "p20_state_blocks": "2", "pre_quant_eval_time_ms": 52757.67647102475, "pre_quant_val_bpb": 1.343710204513414, "pre_quant_val_loss": 2.2687983945880066, "quant_compression": "zstd", "raw_model_path": ".codex/autoresearch/results_cuda_canonical_sp1024_check/baseline_sp1024_80shards_10min_gate_seed42_20260412T041056Z/baseline_sp1024_80shards_10min_gate_seed42_20260412T041056Z_model.pt", "run_id": "baseline_quant_sweep_seed42_20260412T095500Z", "source_run_id": "baseline_sp1024_80shards_10min_gate_seed42_20260412T041056Z", "source_summary_path": ".codex/autoresearch/results_cuda_canonical_sp1024_check/baseline_sp1024_80shards_10min_gate_seed42_20260412T041056Z/summary.json", "variants": [{"fp32_patterns": ["attn_scale", "attn_scales", "mlp_scale", "mlp_scales", "resid_mix", "resid_mixes", "q_gain", "skip_weight", "skip_weights", "bigram.scale"], "mixed_int6_categories": ["attn", "mlp"], "name": "p20_int8_default", "note": "Reproduce the Step 3 leader: attention/MLP int6, P20 large matrices int8, small tensors passthrough.", "post_minus_pre_bpb": 0.016026586094396666, "post_quant_eval_time_ms": 16081.944196950644, "post_quant_val_bpb": 1.3597367906078106, "post_quant_val_loss": 2.295858613881988, "quant_file_bytes": 10062892, "quant_path": ".codex/autoresearch/results_cuda_step4_quant_sweep/baseline_quant_sweep_seed42_20260412T095500Z/baseline_quant_sweep_seed42_20260412T095500Z_p20_int8_default.int6clip.zstd.ptz", "quant_raw_bytes": 20936012, "quant_stats": {"baseline_tensor_bytes": 81889632, "clip_quantiles": [0.999, 0.9995, 0.9999, 0.99999, 1.0], "clipsearch_mse_sum": 0.005724633359932341, "int6_categories": ["attn", "mlp"], "int6_tensor_count": 66, "int8_tensor_count": 1, "num_float_tensors": 112, "num_nonfloat_tensors": 0, "num_tensors": 112, "param_count": 20734552, "passthrough_tensor_count": 45, "quantized_payload_bytes": 20879712}, "quant_time_ms": 20740.368217928335, "total_submission_bytes_quantized": 10294744}, {"fp32_patterns": ["attn_scale", "attn_scales", "mlp_scale", "mlp_scales", "resid_mix", "resid_mixes", "q_gain", "skip_weight", "skip_weights", "bigram.scale", "residual_scale", "input_norm", "output_norm", "primitive.in_projection.projection.bias", "primitive.state_transform_projection.bias", "readout_projection.bias"], "mixed_int6_categories": ["attn", "mlp"], "name": "p20_int8_p20_small_fp32", "note": "Keep P20 residual/norm/bias tensors in fp32 instead of fp16 passthrough.", "post_minus_pre_bpb": 0.016026586094396666, "post_quant_eval_time_ms": 16104.23754202202, "post_quant_val_bpb": 1.3597367906078106, "post_quant_val_loss": 2.295858613881988, "quant_file_bytes": 10062892, "quant_path": ".codex/autoresearch/results_cuda_step4_quant_sweep/baseline_quant_sweep_seed42_20260412T095500Z/baseline_quant_sweep_seed42_20260412T095500Z_p20_int8_p20_small_fp32.int6clip.zstd.ptz", "quant_raw_bytes": 20936012, "quant_stats": {"baseline_tensor_bytes": 81889632, "clip_quantiles": [0.999, 0.9995, 0.9999, 0.99999, 1.0], "clipsearch_mse_sum": 0.005724633359932341, "int6_categories": ["attn", "mlp"], "int6_tensor_count": 66, "int8_tensor_count": 1, "num_float_tensors": 112, "num_nonfloat_tensors": 0, "num_tensors": 112, "param_count": 20734552, "passthrough_tensor_count": 45, "quantized_payload_bytes": 20879712}, "quant_time_ms": 18807.846677023917, "total_submission_bytes_quantized": 10294744}, {"fp32_patterns": ["attn_scale", "attn_scales", "mlp_scale", "mlp_scales", "resid_mix", "resid_mixes", "q_gain", "skip_weight", "skip_weights", "bigram.scale"], "mixed_int6_categories": ["mlp"], "name": "attn_int8_mlp_int6_p20_int8", "note": "Spend bytes protecting attention and P20 at int8 while leaving MLP at int6.", "post_minus_pre_bpb": 0.011263235059893795, "post_quant_eval_time_ms": 16145.048178965226, "post_quant_val_bpb": 1.3549734395733077, "post_quant_val_loss": 2.2878158951888956, "quant_file_bytes": 11816510, "quant_path": ".codex/autoresearch/results_cuda_step4_quant_sweep/baseline_quant_sweep_seed42_20260412T095500Z/baseline_quant_sweep_seed42_20260412T095500Z_attn_int8_mlp_int6_p20_int8.int6clip.zstd.ptz", "quant_raw_bytes": 20936012, "quant_stats": {"baseline_tensor_bytes": 81889632, "clip_quantiles": [0.999, 0.9995, 0.9999, 0.99999, 1.0], "clipsearch_mse_sum": 0.0019073950898018666, "int6_categories": ["mlp"], "int6_tensor_count": 22, "int8_tensor_count": 45, "num_float_tensors": 112, "num_nonfloat_tensors": 0, "num_tensors": 112, "param_count": 20734552, "passthrough_tensor_count": 45, "quantized_payload_bytes": 20879712}, "quant_time_ms": 12305.522667942569, "total_submission_bytes_quantized": 12048362}, {"fp32_patterns": ["attn_scale", "attn_scales", "mlp_scale", "mlp_scales", "resid_mix", "resid_mixes", "q_gain", "skip_weight", "skip_weights", "bigram.scale"], "mixed_int6_categories": ["attn"], "name": "attn_int6_mlp_int8_p20_int8", "note": "Spend bytes protecting MLP and P20 at int8 while leaving attention at int6.", "post_minus_pre_bpb": 0.005479778336315189, "post_quant_eval_time_ms": 16160.725696012378, "post_quant_val_bpb": 1.3491899828497291, "post_quant_val_loss": 2.2780507707702893, "quant_file_bytes": 12825617, "quant_path": ".codex/autoresearch/results_cuda_step4_quant_sweep/baseline_quant_sweep_seed42_20260412T095500Z/baseline_quant_sweep_seed42_20260412T095500Z_attn_int6_mlp_int8_p20_int8.int6clip.zstd.ptz", "quant_raw_bytes": 20936012, "quant_stats": {"baseline_tensor_bytes": 81889632, "clip_quantiles": [0.999, 0.9995, 0.9999, 0.99999, 1.0], "clipsearch_mse_sum": 0.003817238270130474, "int6_categories": ["attn"], "int6_tensor_count": 44, "int8_tensor_count": 23, "num_float_tensors": 112, "num_nonfloat_tensors": 0, "num_tensors": 112, "param_count": 20734552, "passthrough_tensor_count": 45, "quantized_payload_bytes": 20879712}, "quant_time_ms": 12128.029583021998, "total_submission_bytes_quantized": 13057469}, {"fp32_patterns": ["attn_scale", "attn_scales", "mlp_scale", "mlp_scales", "resid_mix", "resid_mixes", "q_gain", "skip_weight", "skip_weights", "bigram.scale"], "mixed_int6_categories": [], "name": "all_large_int8", "note": "Use int8 for all large tensors; this is the upper-bound quantization-protection check.", "post_minus_pre_bpb": 0.0010134384896134385, "post_quant_eval_time_ms": 16176.252738106996, "post_quant_val_bpb": 1.3447236430030274, "post_quant_val_loss": 2.270509543026506, "quant_file_bytes": 14734572, "quant_path": ".codex/autoresearch/results_cuda_step4_quant_sweep/baseline_quant_sweep_seed42_20260412T095500Z/baseline_quant_sweep_seed42_20260412T095500Z_all_large_int8.int6clip.zstd.ptz", "quant_raw_bytes": 20935948, "quant_stats": {"baseline_tensor_bytes": 81889632, "clip_quantiles": [0.999, 0.9995, 0.9999, 0.99999, 1.0], "clipsearch_mse_sum": 0.0, "int6_categories": [], "int6_tensor_count": 0, "int8_tensor_count": 67, "num_float_tensors": 112, "num_nonfloat_tensors": 0, "num_tensors": 112, "param_count": 20734552, "passthrough_tensor_count": 45, "quantized_payload_bytes": 20879712}, "quant_time_ms": 7733.910084003583, "total_submission_bytes_quantized": 14966424}], "vocab_size": 1024} diff --git a/records/track_non_record_16mb/2026-04-12_Fractal_RecurrentPrimitive_SP1024_1xH100/logs/extended_60min_summary_seed42.json b/records/track_non_record_16mb/2026-04-12_Fractal_RecurrentPrimitive_SP1024_1xH100/logs/extended_60min_summary_seed42.json new file mode 100644 index 0000000000..9dcc6413c4 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-12_Fractal_RecurrentPrimitive_SP1024_1xH100/logs/extended_60min_summary_seed42.json @@ -0,0 +1,188 @@ +{ + "baseline_loop_layer_indices": null, + "baseline_loop_repeats": null, + "beta1": 0.9, + "beta2": 0.95, + "code_bytes": 222006, + "code_files": [ + { + "bytes": 150539, + "path": "train_gpt.py" + }, + { + "bytes": 41584, + "path": "p20_runtime.py" + }, + { + "bytes": 29883, + "path": "p20_ttt.py" + } + ], + "compile_prewarm_counts_toward_training": false, + "compile_prewarm_data": "synthetic", + "compile_prewarm_state_restored": true, + "compile_prewarm_steps_executed": 1, + "compile_prewarm_steps_requested": 1, + "compile_prewarm_time_ms": 53096.58326301724, + "compiled_model": true, + "compiled_muon_backend": true, + "ema_applied": true, + "ema_decay": 0.9965, + "ema_enabled": true, + "ema_eval_time_ms": 14744.602988939732, + "ema_start_step": 0, + "embed_lr": 0.6, + "eval_stride": null, + "final_step": 5342, + "grad_accum_steps": 8, + "head_lr": 0.008, + "int8_zlib_bytes": 15957339, + "iterations": 100000, + "loop_delta_scale": 1.0, + "mamba_chunk_size": null, + "mamba_d_state": null, + "mamba_expand": null, + "mamba_include_mlp": null, + "mamba_is_mimo": null, + "mamba_layer_schedule": null, + "mamba_mimo_rank": null, + "mamba_ngroups": null, + "mamba_residual_scale_init": null, + "matrix_lr": 0.04, + "mixed_int6_categories": [ + "attn", + "mlp", + "p20" + ], + "mixed_int6_clip_quantiles": [ + 0.999, + 0.9995, + 0.9999, + 0.99999, + 1.0 + ], + "mlp_mult": 2, + "model_dim": 512, + "model_family": "p20_hybrid", + "model_parameter_count": 22639176, + "non_p20_matrix_parameter_count": 16515072, + "non_p20_scalar_parameter_count": 18504, + "num_heads": 8, + "num_kv_heads": 4, + "num_layers": 11, + "optimizer_family": "muon", + "p20_block_pair_width_cap": 32, + "p20_layer_schedule": "AAAPAAAAPAA", + "p20_loop_repeats": 1, + "p20_matrix_lr": 0.04, + "p20_matrix_parameter_count": 4521984, + "p20_primitive_loop_delta_scale": 1.0, + "p20_primitive_loop_repeats": 1, + "p20_primitive_loop_repeats_by_p": [ + 1, + 1 + ], + "p20_runtime_backend": "triton", + "p20_scalar_lr": 0.04, + "p20_scalar_parameter_count": 10752, + "p20_scan_chunk_size": 0, + "p20_state_blocks": 8, + "p20_state_blocks_requested": "auto", + "p20_ttt_beta1": 0.0, + "p20_ttt_beta2": 0.999, + "p20_ttt_bootstrap_samples": 2000, + "p20_ttt_bootstrap_seed": 12345, + "p20_ttt_chunk_size": 64, + "p20_ttt_context_size": 1024, + "p20_ttt_grad_clip_norm": 0.0, + "p20_ttt_lr": 0.01, + "p20_ttt_max_docs": null, + "p20_ttt_mode": "off", + "p20_ttt_per_doc_path": null, + "p20_ttt_weight_decay": 0.0, + "post_quant_eval_time_ms": 14743.98337607272, + "post_quant_p20_doc_chunks_eval_time_ms": null, + "post_quant_p20_doc_chunks_stats": null, + "post_quant_p20_doc_chunks_val_bpb": null, + "post_quant_p20_doc_chunks_val_loss": null, + "post_quant_p20_ttt_eval_time_ms": null, + "post_quant_p20_ttt_paired_stats": null, + "post_quant_p20_ttt_stats": null, + "post_quant_p20_ttt_val_bpb": null, + "post_quant_p20_ttt_val_loss": null, + "post_quant_sliding_eval_time_ms": null, + "post_quant_sliding_val_bpb": null, + "post_quant_sliding_val_loss": null, + "post_quant_val_bpb": 1.2418190170016838, + "post_quant_val_loss": 2.7011890392867377, + "pre_quant_p20_doc_chunks_eval_time_ms": null, + "pre_quant_p20_doc_chunks_stats": null, + "pre_quant_p20_doc_chunks_val_bpb": null, + "pre_quant_p20_doc_chunks_val_loss": null, + "pre_quant_p20_ttt_eval_time_ms": null, + "pre_quant_p20_ttt_paired_stats": null, + "pre_quant_p20_ttt_stats": null, + "pre_quant_p20_ttt_val_bpb": null, + "pre_quant_p20_ttt_val_loss": null, + "pre_quant_sliding_eval_time_ms": null, + "pre_quant_sliding_val_bpb": null, + "pre_quant_sliding_val_loss": null, + "pre_quant_val_bpb": 1.230186110436162, + "pre_quant_val_loss": 2.675885287870768, + "quant_compression": "zstd", + "quant_file_bytes": 15957339, + "quant_format": "mixed_int6_clipsearch", + "quant_raw_bytes": 22905290, + "quant_stats": { + "baseline_tensor_bytes": 83609888, + "clip_quantiles": [ + 0.999, + 0.9995, + 0.9999, + 0.99999, + 1.0 + ], + "clipsearch_mse_sum": 0.010233231827442069, + "int6_categories": [ + "attn", + "mlp", + "p20" + ], + "int6_tensor_count": 62, + "int8_tensor_count": 1, + "num_float_tensors": 117, + "num_nonfloat_tensors": 0, + "num_tensors": 117, + "param_count": 22639176, + "passthrough_tensor_count": 54, + "quantized_payload_bytes": 22847776 + }, + "raw_model_bytes": 83673200, + "recurrence_activated_step": null, + "recurrence_activated_train_time_ms": null, + "recurrence_aux_applications": null, + "recurrence_aux_kind": null, + "recurrence_aux_mode": null, + "recurrence_aux_scale_init": null, + "recurrence_decoder_schedule": null, + "recurrence_encoder_schedule": null, + "recurrence_start_fraction": null, + "recurrence_start_step": null, + "run_id": "p20_AAAPAAAAPAA_compileprewarm_offclock_60min_wd4000_ema09965_int6clip_zstd_clip1_seed42_20260412T014211Z", + "scalar_lr": 0.04, + "stopped_early": true, + "sw_eval_batch": 32, + "tied_embed_lr": 0.05, + "total_submission_bytes_int8_zlib": 16179345, + "total_submission_bytes_quantized": 16179345, + "total_submission_bytes_raw": 83895206, + "train_batch_tokens": 524288, + "train_seq_len": 1024, + "train_time_ms": 3600033.7083761115, + "val_batch_size": 524288, + "val_max_batches": null, + "vocab_size": 3072, + "warmup_steps_executed": 0, + "warmup_time_ms": 0.0, + "world_size": 1 +} diff --git a/records/track_non_record_16mb/2026-04-12_Fractal_RecurrentPrimitive_SP1024_1xH100/logs/quant_sweep_seed42.json b/records/track_non_record_16mb/2026-04-12_Fractal_RecurrentPrimitive_SP1024_1xH100/logs/quant_sweep_seed42.json new file mode 100644 index 0000000000..5649658306 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-12_Fractal_RecurrentPrimitive_SP1024_1xH100/logs/quant_sweep_seed42.json @@ -0,0 +1,309 @@ +{ + "best_post_quant_val_bpb": 1.357619300696428, + "best_post_quant_val_loss": 2.2922833208646662, + "best_total_submission_bytes_quantized": 14440584, + "best_variant": "all_large_int8", + "code_bytes": 231852, + "code_files": [ + { + "bytes": 160385, + "path": "train_gpt.py" + }, + { + "bytes": 41584, + "path": "p20_runtime.py" + }, + { + "bytes": 29883, + "path": "p20_ttt.py" + } + ], + "compiled_model": true, + "event": "checkpoint_requant_sweep_completed", + "experiment_kind": "step4_quantization_protection_sweep", + "mixed_int6_clip_quantiles": [ + 0.999, + 0.9995, + 0.9999, + 0.99999, + 1.0 + ], + "mlp_hidden_dim": 1024, + "mlp_mult": 2, + "model_dim": 512, + "model_family": "p20_hybrid", + "num_heads": 8, + "num_kv_heads": 4, + "num_layers": 11, + "p20_block_pair_width_cap": 32, + "p20_layer_schedule": "AAAAAPAAAAA", + "p20_runtime_backend": "triton", + "p20_state_blocks": "auto", + "pre_quant_eval_time_ms": 52896.35168085806, + "pre_quant_val_bpb": 1.3562212709803687, + "pre_quant_val_loss": 2.2899228062501846, + "quant_compression": "zstd", + "raw_model_path": ".codex/autoresearch/results_cuda_canonical_sp1024_check/p20_AAAAAPAAAAA_sp1024_80shards_10min_gate_seed42_20260412T043154Z/p20_AAAAAPAAAAA_sp1024_80shards_10min_gate_seed42_20260412T043154Z_model.pt", + "run_id": "p20_AAAAAPAAAAA_quant_sweep_seed42_20260412T094500Z", + "source_run_id": "p20_AAAAAPAAAAA_sp1024_80shards_10min_gate_seed42_20260412T043154Z", + "source_summary_path": ".codex/autoresearch/results_cuda_canonical_sp1024_check/p20_AAAAAPAAAAA_sp1024_80shards_10min_gate_seed42_20260412T043154Z/summary.json", + "variants": [ + { + "fp32_patterns": [ + "attn_scale", + "attn_scales", + "mlp_scale", + "mlp_scales", + "resid_mix", + "resid_mixes", + "q_gain", + "skip_weight", + "skip_weights", + "bigram.scale" + ], + "mixed_int6_categories": [ + "attn", + "mlp" + ], + "name": "p20_int8_default", + "note": "Reproduce the Step 3 leader: attention/MLP int6, P20 large matrices int8, small tensors passthrough.", + "post_minus_pre_bpb": 0.019788843149066704, + "post_quant_eval_time_ms": 16966.785185039043, + "post_quant_val_bpb": 1.3760101141294354, + "post_quant_val_loss": 2.323335438986433, + "quant_file_bytes": 9813093, + "quant_path": ".codex/autoresearch/results_cuda_step4_quant_sweep/p20_AAAAAPAAAAA_quant_sweep_seed42_20260412T094500Z/p20_AAAAAPAAAAA_quant_sweep_seed42_20260412T094500Z_p20_int8_default.int6clip.zstd.ptz", + "quant_raw_bytes": 21389068, + "quant_stats": { + "baseline_tensor_bytes": 81696064, + "clip_quantiles": [ + 0.999, + 0.9995, + 0.9999, + 0.99999, + 1.0 + ], + "clipsearch_mse_sum": 0.005375599743274506, + "int6_categories": [ + "attn", + "mlp" + ], + "int6_tensor_count": 62, + "int8_tensor_count": 3, + "num_float_tensors": 114, + "num_nonfloat_tensors": 0, + "num_tensors": 114, + "param_count": 21161296, + "passthrough_tensor_count": 49, + "quantized_payload_bytes": 21332288 + }, + "quant_time_ms": 21683.49984800443, + "total_submission_bytes_quantized": 10044945 + }, + { + "fp32_patterns": [ + "attn_scale", + "attn_scales", + "mlp_scale", + "mlp_scales", + "resid_mix", + "resid_mixes", + "q_gain", + "skip_weight", + "skip_weights", + "bigram.scale", + "residual_scale", + "input_norm", + "output_norm", + "primitive.in_projection.projection.bias", + "primitive.state_transform_projection.bias", + "readout_projection.bias" + ], + "mixed_int6_categories": [ + "attn", + "mlp" + ], + "name": "p20_int8_p20_small_fp32", + "note": "Keep P20 residual/norm/bias tensors in fp32 instead of fp16 passthrough.", + "post_minus_pre_bpb": 0.01978966339088739, + "post_quant_eval_time_ms": 16966.824583010748, + "post_quant_val_bpb": 1.376010934371256, + "post_quant_val_loss": 2.323336823930389, + "quant_file_bytes": 9834026, + "quant_path": ".codex/autoresearch/results_cuda_step4_quant_sweep/p20_AAAAAPAAAAA_quant_sweep_seed42_20260412T094500Z/p20_AAAAAPAAAAA_quant_sweep_seed42_20260412T094500Z_p20_int8_p20_small_fp32.int6clip.zstd.ptz", + "quant_raw_bytes": 21399628, + "quant_stats": { + "baseline_tensor_bytes": 81696064, + "clip_quantiles": [ + 0.999, + 0.9995, + 0.9999, + 0.99999, + 1.0 + ], + "clipsearch_mse_sum": 0.005375599743274506, + "int6_categories": [ + "attn", + "mlp" + ], + "int6_tensor_count": 62, + "int8_tensor_count": 3, + "num_float_tensors": 114, + "num_nonfloat_tensors": 0, + "num_tensors": 114, + "param_count": 21161296, + "passthrough_tensor_count": 49, + "quantized_payload_bytes": 21343040 + }, + "quant_time_ms": 20818.484775023535, + "total_submission_bytes_quantized": 10065878 + }, + { + "fp32_patterns": [ + "attn_scale", + "attn_scales", + "mlp_scale", + "mlp_scales", + "resid_mix", + "resid_mixes", + "q_gain", + "skip_weight", + "skip_weights", + "bigram.scale" + ], + "mixed_int6_categories": [ + "mlp" + ], + "name": "attn_int8_mlp_int6_p20_int8", + "note": "Spend bytes protecting attention and P20 at int8 while leaving MLP at int6.", + "post_minus_pre_bpb": 0.013572220075183772, + "post_quant_eval_time_ms": 16995.231362059712, + "post_quant_val_bpb": 1.3697934910555525, + "post_quant_val_loss": 2.312838931329939, + "quant_file_bytes": 11586467, + "quant_path": ".codex/autoresearch/results_cuda_step4_quant_sweep/p20_AAAAAPAAAAA_quant_sweep_seed42_20260412T094500Z/p20_AAAAAPAAAAA_quant_sweep_seed42_20260412T094500Z_attn_int8_mlp_int6_p20_int8.int6clip.zstd.ptz", + "quant_raw_bytes": 21389004, + "quant_stats": { + "baseline_tensor_bytes": 81696064, + "clip_quantiles": [ + 0.999, + 0.9995, + 0.9999, + 0.99999, + 1.0 + ], + "clipsearch_mse_sum": 0.0019072179202339612, + "int6_categories": [ + "mlp" + ], + "int6_tensor_count": 22, + "int8_tensor_count": 43, + "num_float_tensors": 114, + "num_nonfloat_tensors": 0, + "num_tensors": 114, + "param_count": 21161296, + "passthrough_tensor_count": 49, + "quantized_payload_bytes": 21332288 + }, + "quant_time_ms": 14076.096609933302, + "total_submission_bytes_quantized": 11818319 + }, + { + "fp32_patterns": [ + "attn_scale", + "attn_scales", + "mlp_scale", + "mlp_scales", + "resid_mix", + "resid_mixes", + "q_gain", + "skip_weight", + "skip_weights", + "bigram.scale" + ], + "mixed_int6_categories": [ + "attn" + ], + "name": "attn_int6_mlp_int8_p20_int8", + "note": "Spend bytes protecting MLP and P20 at int8 while leaving attention at int6.", + "post_minus_pre_bpb": 0.007291756789034487, + "post_quant_eval_time_ms": 16998.020959086716, + "post_quant_val_bpb": 1.3635130277694032, + "post_quant_val_loss": 2.302234632149191, + "quant_file_bytes": 12307455, + "quant_path": ".codex/autoresearch/results_cuda_step4_quant_sweep/p20_AAAAAPAAAAA_quant_sweep_seed42_20260412T094500Z/p20_AAAAAPAAAAA_quant_sweep_seed42_20260412T094500Z_attn_int6_mlp_int8_p20_int8.int6clip.zstd.ptz", + "quant_raw_bytes": 21389004, + "quant_stats": { + "baseline_tensor_bytes": 81696064, + "clip_quantiles": [ + 0.999, + 0.9995, + 0.9999, + 0.99999, + 1.0 + ], + "clipsearch_mse_sum": 0.003468381823040545, + "int6_categories": [ + "attn" + ], + "int6_tensor_count": 40, + "int8_tensor_count": 25, + "num_float_tensors": 114, + "num_nonfloat_tensors": 0, + "num_tensors": 114, + "param_count": 21161296, + "passthrough_tensor_count": 49, + "quantized_payload_bytes": 21332288 + }, + "quant_time_ms": 14366.346410941333, + "total_submission_bytes_quantized": 12539307 + }, + { + "fp32_patterns": [ + "attn_scale", + "attn_scales", + "mlp_scale", + "mlp_scales", + "resid_mix", + "resid_mixes", + "q_gain", + "skip_weight", + "skip_weights", + "bigram.scale" + ], + "mixed_int6_categories": [], + "name": "all_large_int8", + "note": "Use int8 for all large tensors; this is the upper-bound quantization-protection check.", + "post_minus_pre_bpb": 0.0013980297160594013, + "post_quant_eval_time_ms": 17020.262690028176, + "post_quant_val_bpb": 1.357619300696428, + "post_quant_val_loss": 2.2922833208646662, + "quant_file_bytes": 14208732, + "quant_path": ".codex/autoresearch/results_cuda_step4_quant_sweep/p20_AAAAAPAAAAA_quant_sweep_seed42_20260412T094500Z/p20_AAAAAPAAAAA_quant_sweep_seed42_20260412T094500Z_all_large_int8.int6clip.zstd.ptz", + "quant_raw_bytes": 21389004, + "quant_stats": { + "baseline_tensor_bytes": 81696064, + "clip_quantiles": [ + 0.999, + 0.9995, + 0.9999, + 0.99999, + 1.0 + ], + "clipsearch_mse_sum": 0.0, + "int6_categories": [], + "int6_tensor_count": 0, + "int8_tensor_count": 65, + "num_float_tensors": 114, + "num_nonfloat_tensors": 0, + "num_tensors": 114, + "param_count": 21161296, + "passthrough_tensor_count": 49, + "quantized_payload_bytes": 21332288 + }, + "quant_time_ms": 11101.043669972569, + "total_submission_bytes_quantized": 14440584 + } + ], + "vocab_size": 1024 +} diff --git a/records/track_non_record_16mb/2026-04-12_Fractal_RecurrentPrimitive_SP1024_1xH100/logs/quant_sweep_seed42.log b/records/track_non_record_16mb/2026-04-12_Fractal_RecurrentPrimitive_SP1024_1xH100/logs/quant_sweep_seed42.log new file mode 100644 index 0000000000..d7e7e3a842 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-12_Fractal_RecurrentPrimitive_SP1024_1xH100/logs/quant_sweep_seed42.log @@ -0,0 +1,11 @@ +/workspace/parameter-golf-tokenizer/parameter-golf/.codex/autoresearch/scripts/requant_checkpoint_sweep.py:157: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + state = torch.load(raw_model_path, map_location="cpu") +/workspace/parameter-golf-tokenizer/parameter-golf/.codex/autoresearch/scripts/requant_checkpoint_sweep.py:201: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + roundtrip_obj = torch.load( +variant {"fp32_patterns": ["attn_scale", "attn_scales", "mlp_scale", "mlp_scales", "resid_mix", "resid_mixes", "q_gain", "skip_weight", "skip_weights", "bigram.scale"], "mixed_int6_categories": ["attn", "mlp"], "name": "p20_int8_default", "note": "Reproduce the Step 3 leader: attention/MLP int6, P20 large matrices int8, small tensors passthrough.", "post_minus_pre_bpb": 0.019788843149066704, "post_quant_eval_time_ms": 16966.785185039043, "post_quant_val_bpb": 1.3760101141294354, "post_quant_val_loss": 2.323335438986433, "quant_file_bytes": 9813093, "quant_path": ".codex/autoresearch/results_cuda_step4_quant_sweep/p20_AAAAAPAAAAA_quant_sweep_seed42_20260412T094500Z/p20_AAAAAPAAAAA_quant_sweep_seed42_20260412T094500Z_p20_int8_default.int6clip.zstd.ptz", "quant_raw_bytes": 21389068, "quant_stats": {"baseline_tensor_bytes": 81696064, "clip_quantiles": [0.999, 0.9995, 0.9999, 0.99999, 1.0], "clipsearch_mse_sum": 0.005375599743274506, "int6_categories": ["attn", "mlp"], "int6_tensor_count": 62, "int8_tensor_count": 3, "num_float_tensors": 114, "num_nonfloat_tensors": 0, "num_tensors": 114, "param_count": 21161296, "passthrough_tensor_count": 49, "quantized_payload_bytes": 21332288}, "quant_time_ms": 21683.49984800443, "total_submission_bytes_quantized": 10044945} +variant {"fp32_patterns": ["attn_scale", "attn_scales", "mlp_scale", "mlp_scales", "resid_mix", "resid_mixes", "q_gain", "skip_weight", "skip_weights", "bigram.scale", "residual_scale", "input_norm", "output_norm", "primitive.in_projection.projection.bias", "primitive.state_transform_projection.bias", "readout_projection.bias"], "mixed_int6_categories": ["attn", "mlp"], "name": "p20_int8_p20_small_fp32", "note": "Keep P20 residual/norm/bias tensors in fp32 instead of fp16 passthrough.", "post_minus_pre_bpb": 0.01978966339088739, "post_quant_eval_time_ms": 16966.824583010748, "post_quant_val_bpb": 1.376010934371256, "post_quant_val_loss": 2.323336823930389, "quant_file_bytes": 9834026, "quant_path": ".codex/autoresearch/results_cuda_step4_quant_sweep/p20_AAAAAPAAAAA_quant_sweep_seed42_20260412T094500Z/p20_AAAAAPAAAAA_quant_sweep_seed42_20260412T094500Z_p20_int8_p20_small_fp32.int6clip.zstd.ptz", "quant_raw_bytes": 21399628, "quant_stats": {"baseline_tensor_bytes": 81696064, "clip_quantiles": [0.999, 0.9995, 0.9999, 0.99999, 1.0], "clipsearch_mse_sum": 0.005375599743274506, "int6_categories": ["attn", "mlp"], "int6_tensor_count": 62, "int8_tensor_count": 3, "num_float_tensors": 114, "num_nonfloat_tensors": 0, "num_tensors": 114, "param_count": 21161296, "passthrough_tensor_count": 49, "quantized_payload_bytes": 21343040}, "quant_time_ms": 20818.484775023535, "total_submission_bytes_quantized": 10065878} +variant {"fp32_patterns": ["attn_scale", "attn_scales", "mlp_scale", "mlp_scales", "resid_mix", "resid_mixes", "q_gain", "skip_weight", "skip_weights", "bigram.scale"], "mixed_int6_categories": ["mlp"], "name": "attn_int8_mlp_int6_p20_int8", "note": "Spend bytes protecting attention and P20 at int8 while leaving MLP at int6.", "post_minus_pre_bpb": 0.013572220075183772, "post_quant_eval_time_ms": 16995.231362059712, "post_quant_val_bpb": 1.3697934910555525, "post_quant_val_loss": 2.312838931329939, "quant_file_bytes": 11586467, "quant_path": ".codex/autoresearch/results_cuda_step4_quant_sweep/p20_AAAAAPAAAAA_quant_sweep_seed42_20260412T094500Z/p20_AAAAAPAAAAA_quant_sweep_seed42_20260412T094500Z_attn_int8_mlp_int6_p20_int8.int6clip.zstd.ptz", "quant_raw_bytes": 21389004, "quant_stats": {"baseline_tensor_bytes": 81696064, "clip_quantiles": [0.999, 0.9995, 0.9999, 0.99999, 1.0], "clipsearch_mse_sum": 0.0019072179202339612, "int6_categories": ["mlp"], "int6_tensor_count": 22, "int8_tensor_count": 43, "num_float_tensors": 114, "num_nonfloat_tensors": 0, "num_tensors": 114, "param_count": 21161296, "passthrough_tensor_count": 49, "quantized_payload_bytes": 21332288}, "quant_time_ms": 14076.096609933302, "total_submission_bytes_quantized": 11818319} +variant {"fp32_patterns": ["attn_scale", "attn_scales", "mlp_scale", "mlp_scales", "resid_mix", "resid_mixes", "q_gain", "skip_weight", "skip_weights", "bigram.scale"], "mixed_int6_categories": ["attn"], "name": "attn_int6_mlp_int8_p20_int8", "note": "Spend bytes protecting MLP and P20 at int8 while leaving attention at int6.", "post_minus_pre_bpb": 0.007291756789034487, "post_quant_eval_time_ms": 16998.020959086716, "post_quant_val_bpb": 1.3635130277694032, "post_quant_val_loss": 2.302234632149191, "quant_file_bytes": 12307455, "quant_path": ".codex/autoresearch/results_cuda_step4_quant_sweep/p20_AAAAAPAAAAA_quant_sweep_seed42_20260412T094500Z/p20_AAAAAPAAAAA_quant_sweep_seed42_20260412T094500Z_attn_int6_mlp_int8_p20_int8.int6clip.zstd.ptz", "quant_raw_bytes": 21389004, "quant_stats": {"baseline_tensor_bytes": 81696064, "clip_quantiles": [0.999, 0.9995, 0.9999, 0.99999, 1.0], "clipsearch_mse_sum": 0.003468381823040545, "int6_categories": ["attn"], "int6_tensor_count": 40, "int8_tensor_count": 25, "num_float_tensors": 114, "num_nonfloat_tensors": 0, "num_tensors": 114, "param_count": 21161296, "passthrough_tensor_count": 49, "quantized_payload_bytes": 21332288}, "quant_time_ms": 14366.346410941333, "total_submission_bytes_quantized": 12539307} +variant {"fp32_patterns": ["attn_scale", "attn_scales", "mlp_scale", "mlp_scales", "resid_mix", "resid_mixes", "q_gain", "skip_weight", "skip_weights", "bigram.scale"], "mixed_int6_categories": [], "name": "all_large_int8", "note": "Use int8 for all large tensors; this is the upper-bound quantization-protection check.", "post_minus_pre_bpb": 0.0013980297160594013, "post_quant_eval_time_ms": 17020.262690028176, "post_quant_val_bpb": 1.357619300696428, "post_quant_val_loss": 2.2922833208646662, "quant_file_bytes": 14208732, "quant_path": ".codex/autoresearch/results_cuda_step4_quant_sweep/p20_AAAAAPAAAAA_quant_sweep_seed42_20260412T094500Z/p20_AAAAAPAAAAA_quant_sweep_seed42_20260412T094500Z_all_large_int8.int6clip.zstd.ptz", "quant_raw_bytes": 21389004, "quant_stats": {"baseline_tensor_bytes": 81696064, "clip_quantiles": [0.999, 0.9995, 0.9999, 0.99999, 1.0], "clipsearch_mse_sum": 0.0, "int6_categories": [], "int6_tensor_count": 0, "int8_tensor_count": 65, "num_float_tensors": 114, "num_nonfloat_tensors": 0, "num_tensors": 114, "param_count": 21161296, "passthrough_tensor_count": 49, "quantized_payload_bytes": 21332288}, "quant_time_ms": 11101.043669972569, "total_submission_bytes_quantized": 14440584} +summary_path .codex/autoresearch/results_cuda_step4_quant_sweep/p20_AAAAAPAAAAA_quant_sweep_seed42_20260412T094500Z/summary.json +{"best_post_quant_val_bpb": 1.357619300696428, "best_post_quant_val_loss": 2.2922833208646662, "best_total_submission_bytes_quantized": 14440584, "best_variant": "all_large_int8", "code_bytes": 231852, "code_files": [{"bytes": 160385, "path": "train_gpt.py"}, {"bytes": 41584, "path": "p20_runtime.py"}, {"bytes": 29883, "path": "p20_ttt.py"}], "compiled_model": true, "event": "checkpoint_requant_sweep_completed", "experiment_kind": "step4_quantization_protection_sweep", "mixed_int6_clip_quantiles": [0.999, 0.9995, 0.9999, 0.99999, 1.0], "mlp_hidden_dim": 1024, "mlp_mult": 2, "model_dim": 512, "model_family": "p20_hybrid", "num_heads": 8, "num_kv_heads": 4, "num_layers": 11, "p20_block_pair_width_cap": 32, "p20_layer_schedule": "AAAAAPAAAAA", "p20_runtime_backend": "triton", "p20_state_blocks": "auto", "pre_quant_eval_time_ms": 52896.35168085806, "pre_quant_val_bpb": 1.3562212709803687, "pre_quant_val_loss": 2.2899228062501846, "quant_compression": "zstd", "raw_model_path": ".codex/autoresearch/results_cuda_canonical_sp1024_check/p20_AAAAAPAAAAA_sp1024_80shards_10min_gate_seed42_20260412T043154Z/p20_AAAAAPAAAAA_sp1024_80shards_10min_gate_seed42_20260412T043154Z_model.pt", "run_id": "p20_AAAAAPAAAAA_quant_sweep_seed42_20260412T094500Z", "source_run_id": "p20_AAAAAPAAAAA_sp1024_80shards_10min_gate_seed42_20260412T043154Z", "source_summary_path": ".codex/autoresearch/results_cuda_canonical_sp1024_check/p20_AAAAAPAAAAA_sp1024_80shards_10min_gate_seed42_20260412T043154Z/summary.json", "variants": [{"fp32_patterns": ["attn_scale", "attn_scales", "mlp_scale", "mlp_scales", "resid_mix", "resid_mixes", "q_gain", "skip_weight", "skip_weights", "bigram.scale"], "mixed_int6_categories": ["attn", "mlp"], "name": "p20_int8_default", "note": "Reproduce the Step 3 leader: attention/MLP int6, P20 large matrices int8, small tensors passthrough.", "post_minus_pre_bpb": 0.019788843149066704, "post_quant_eval_time_ms": 16966.785185039043, "post_quant_val_bpb": 1.3760101141294354, "post_quant_val_loss": 2.323335438986433, "quant_file_bytes": 9813093, "quant_path": ".codex/autoresearch/results_cuda_step4_quant_sweep/p20_AAAAAPAAAAA_quant_sweep_seed42_20260412T094500Z/p20_AAAAAPAAAAA_quant_sweep_seed42_20260412T094500Z_p20_int8_default.int6clip.zstd.ptz", "quant_raw_bytes": 21389068, "quant_stats": {"baseline_tensor_bytes": 81696064, "clip_quantiles": [0.999, 0.9995, 0.9999, 0.99999, 1.0], "clipsearch_mse_sum": 0.005375599743274506, "int6_categories": ["attn", "mlp"], "int6_tensor_count": 62, "int8_tensor_count": 3, "num_float_tensors": 114, "num_nonfloat_tensors": 0, "num_tensors": 114, "param_count": 21161296, "passthrough_tensor_count": 49, "quantized_payload_bytes": 21332288}, "quant_time_ms": 21683.49984800443, "total_submission_bytes_quantized": 10044945}, {"fp32_patterns": ["attn_scale", "attn_scales", "mlp_scale", "mlp_scales", "resid_mix", "resid_mixes", "q_gain", "skip_weight", "skip_weights", "bigram.scale", "residual_scale", "input_norm", "output_norm", "primitive.in_projection.projection.bias", "primitive.state_transform_projection.bias", "readout_projection.bias"], "mixed_int6_categories": ["attn", "mlp"], "name": "p20_int8_p20_small_fp32", "note": "Keep P20 residual/norm/bias tensors in fp32 instead of fp16 passthrough.", "post_minus_pre_bpb": 0.01978966339088739, "post_quant_eval_time_ms": 16966.824583010748, "post_quant_val_bpb": 1.376010934371256, "post_quant_val_loss": 2.323336823930389, "quant_file_bytes": 9834026, "quant_path": ".codex/autoresearch/results_cuda_step4_quant_sweep/p20_AAAAAPAAAAA_quant_sweep_seed42_20260412T094500Z/p20_AAAAAPAAAAA_quant_sweep_seed42_20260412T094500Z_p20_int8_p20_small_fp32.int6clip.zstd.ptz", "quant_raw_bytes": 21399628, "quant_stats": {"baseline_tensor_bytes": 81696064, "clip_quantiles": [0.999, 0.9995, 0.9999, 0.99999, 1.0], "clipsearch_mse_sum": 0.005375599743274506, "int6_categories": ["attn", "mlp"], "int6_tensor_count": 62, "int8_tensor_count": 3, "num_float_tensors": 114, "num_nonfloat_tensors": 0, "num_tensors": 114, "param_count": 21161296, "passthrough_tensor_count": 49, "quantized_payload_bytes": 21343040}, "quant_time_ms": 20818.484775023535, "total_submission_bytes_quantized": 10065878}, {"fp32_patterns": ["attn_scale", "attn_scales", "mlp_scale", "mlp_scales", "resid_mix", "resid_mixes", "q_gain", "skip_weight", "skip_weights", "bigram.scale"], "mixed_int6_categories": ["mlp"], "name": "attn_int8_mlp_int6_p20_int8", "note": "Spend bytes protecting attention and P20 at int8 while leaving MLP at int6.", "post_minus_pre_bpb": 0.013572220075183772, "post_quant_eval_time_ms": 16995.231362059712, "post_quant_val_bpb": 1.3697934910555525, "post_quant_val_loss": 2.312838931329939, "quant_file_bytes": 11586467, "quant_path": ".codex/autoresearch/results_cuda_step4_quant_sweep/p20_AAAAAPAAAAA_quant_sweep_seed42_20260412T094500Z/p20_AAAAAPAAAAA_quant_sweep_seed42_20260412T094500Z_attn_int8_mlp_int6_p20_int8.int6clip.zstd.ptz", "quant_raw_bytes": 21389004, "quant_stats": {"baseline_tensor_bytes": 81696064, "clip_quantiles": [0.999, 0.9995, 0.9999, 0.99999, 1.0], "clipsearch_mse_sum": 0.0019072179202339612, "int6_categories": ["mlp"], "int6_tensor_count": 22, "int8_tensor_count": 43, "num_float_tensors": 114, "num_nonfloat_tensors": 0, "num_tensors": 114, "param_count": 21161296, "passthrough_tensor_count": 49, "quantized_payload_bytes": 21332288}, "quant_time_ms": 14076.096609933302, "total_submission_bytes_quantized": 11818319}, {"fp32_patterns": ["attn_scale", "attn_scales", "mlp_scale", "mlp_scales", "resid_mix", "resid_mixes", "q_gain", "skip_weight", "skip_weights", "bigram.scale"], "mixed_int6_categories": ["attn"], "name": "attn_int6_mlp_int8_p20_int8", "note": "Spend bytes protecting MLP and P20 at int8 while leaving attention at int6.", "post_minus_pre_bpb": 0.007291756789034487, "post_quant_eval_time_ms": 16998.020959086716, "post_quant_val_bpb": 1.3635130277694032, "post_quant_val_loss": 2.302234632149191, "quant_file_bytes": 12307455, "quant_path": ".codex/autoresearch/results_cuda_step4_quant_sweep/p20_AAAAAPAAAAA_quant_sweep_seed42_20260412T094500Z/p20_AAAAAPAAAAA_quant_sweep_seed42_20260412T094500Z_attn_int6_mlp_int8_p20_int8.int6clip.zstd.ptz", "quant_raw_bytes": 21389004, "quant_stats": {"baseline_tensor_bytes": 81696064, "clip_quantiles": [0.999, 0.9995, 0.9999, 0.99999, 1.0], "clipsearch_mse_sum": 0.003468381823040545, "int6_categories": ["attn"], "int6_tensor_count": 40, "int8_tensor_count": 25, "num_float_tensors": 114, "num_nonfloat_tensors": 0, "num_tensors": 114, "param_count": 21161296, "passthrough_tensor_count": 49, "quantized_payload_bytes": 21332288}, "quant_time_ms": 14366.346410941333, "total_submission_bytes_quantized": 12539307}, {"fp32_patterns": ["attn_scale", "attn_scales", "mlp_scale", "mlp_scales", "resid_mix", "resid_mixes", "q_gain", "skip_weight", "skip_weights", "bigram.scale"], "mixed_int6_categories": [], "name": "all_large_int8", "note": "Use int8 for all large tensors; this is the upper-bound quantization-protection check.", "post_minus_pre_bpb": 0.0013980297160594013, "post_quant_eval_time_ms": 17020.262690028176, "post_quant_val_bpb": 1.357619300696428, "post_quant_val_loss": 2.2922833208646662, "quant_file_bytes": 14208732, "quant_path": ".codex/autoresearch/results_cuda_step4_quant_sweep/p20_AAAAAPAAAAA_quant_sweep_seed42_20260412T094500Z/p20_AAAAAPAAAAA_quant_sweep_seed42_20260412T094500Z_all_large_int8.int6clip.zstd.ptz", "quant_raw_bytes": 21389004, "quant_stats": {"baseline_tensor_bytes": 81696064, "clip_quantiles": [0.999, 0.9995, 0.9999, 0.99999, 1.0], "clipsearch_mse_sum": 0.0, "int6_categories": [], "int6_tensor_count": 0, "int8_tensor_count": 65, "num_float_tensors": 114, "num_nonfloat_tensors": 0, "num_tensors": 114, "param_count": 21161296, "passthrough_tensor_count": 49, "quantized_payload_bytes": 21332288}, "quant_time_ms": 11101.043669972569, "total_submission_bytes_quantized": 14440584}], "vocab_size": 1024} diff --git a/records/track_non_record_16mb/2026-04-12_Fractal_RecurrentPrimitive_SP1024_1xH100/logs/source_summary_seed42.json b/records/track_non_record_16mb/2026-04-12_Fractal_RecurrentPrimitive_SP1024_1xH100/logs/source_summary_seed42.json new file mode 100644 index 0000000000..187f71ab04 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-12_Fractal_RecurrentPrimitive_SP1024_1xH100/logs/source_summary_seed42.json @@ -0,0 +1,191 @@ +{ + "baseline_loop_layer_indices": null, + "baseline_loop_repeats": null, + "beta1": 0.9, + "beta2": 0.95, + "code_bytes": 223725, + "code_files": [ + { + "bytes": 152258, + "path": "train_gpt.py" + }, + { + "bytes": 41584, + "path": "p20_runtime.py" + }, + { + "bytes": 29883, + "path": "p20_ttt.py" + } + ], + "compile_prewarm_counts_toward_training": false, + "compile_prewarm_data": "synthetic", + "compile_prewarm_state_restored": true, + "compile_prewarm_steps_executed": 1, + "compile_prewarm_steps_requested": 1, + "compile_prewarm_time_ms": 56584.48242209852, + "compiled_model": true, + "compiled_muon_backend": true, + "ema_applied": true, + "ema_decay": 0.9965, + "ema_enabled": true, + "ema_eval_time_ms": 17049.33060402982, + "ema_start_step": 0, + "embed_lr": 0.6, + "eval_stride": null, + "final_step": 988, + "grad_accum_steps": 8, + "head_lr": 0.008, + "int8_zlib_bytes": 9524047, + "iterations": 100000, + "loop_delta_scale": 1.0, + "mamba_chunk_size": null, + "mamba_d_state": null, + "mamba_expand": null, + "mamba_include_mlp": null, + "mamba_is_mimo": null, + "mamba_layer_schedule": null, + "mamba_mimo_rank": null, + "mamba_ngroups": null, + "mamba_residual_scale_init": null, + "matrix_lr": 0.04, + "min_train_shards": 80, + "min_train_tokens": 8000000000, + "mixed_int6_categories": [ + "attn", + "mlp", + "p20" + ], + "mixed_int6_clip_quantiles": [ + 0.999, + 0.9995, + 0.9999, + 0.99999, + 1.0 + ], + "mlp_mult": 2, + "model_dim": 512, + "model_family": "p20_hybrid", + "model_parameter_count": 21161296, + "non_p20_matrix_parameter_count": 18350080, + "non_p20_scalar_parameter_count": 20560, + "num_heads": 8, + "num_kv_heads": 4, + "num_layers": 11, + "optimizer_family": "muon", + "p20_block_pair_width_cap": 32, + "p20_layer_schedule": "AAAAAPAAAAA", + "p20_loop_repeats": 1, + "p20_matrix_lr": 0.04, + "p20_matrix_parameter_count": 2260992, + "p20_primitive_loop_delta_scale": 1.0, + "p20_primitive_loop_repeats": 1, + "p20_primitive_loop_repeats_by_p": [ + 1 + ], + "p20_runtime_backend": "triton", + "p20_scalar_lr": 0.04, + "p20_scalar_parameter_count": 5376, + "p20_scan_chunk_size": 0, + "p20_state_blocks": 8, + "p20_state_blocks_requested": "auto", + "p20_ttt_beta1": 0.0, + "p20_ttt_beta2": 0.999, + "p20_ttt_bootstrap_samples": 2000, + "p20_ttt_bootstrap_seed": 12345, + "p20_ttt_chunk_size": 64, + "p20_ttt_context_size": 1024, + "p20_ttt_grad_clip_norm": 0.0, + "p20_ttt_lr": 0.01, + "p20_ttt_max_docs": null, + "p20_ttt_mode": "off", + "p20_ttt_per_doc_path": null, + "p20_ttt_weight_decay": 0.0, + "post_quant_eval_time_ms": 16966.56051906757, + "post_quant_p20_doc_chunks_eval_time_ms": null, + "post_quant_p20_doc_chunks_stats": null, + "post_quant_p20_doc_chunks_val_bpb": null, + "post_quant_p20_doc_chunks_val_loss": null, + "post_quant_p20_ttt_eval_time_ms": null, + "post_quant_p20_ttt_paired_stats": null, + "post_quant_p20_ttt_stats": null, + "post_quant_p20_ttt_val_bpb": null, + "post_quant_p20_ttt_val_loss": null, + "post_quant_sliding_eval_time_ms": null, + "post_quant_sliding_val_bpb": null, + "post_quant_sliding_val_loss": null, + "post_quant_val_bpb": 1.376889176496691, + "post_quant_val_loss": 2.3248196989711185, + "pre_quant_p20_doc_chunks_eval_time_ms": null, + "pre_quant_p20_doc_chunks_stats": null, + "pre_quant_p20_doc_chunks_val_bpb": null, + "pre_quant_p20_doc_chunks_val_loss": null, + "pre_quant_p20_ttt_eval_time_ms": null, + "pre_quant_p20_ttt_paired_stats": null, + "pre_quant_p20_ttt_stats": null, + "pre_quant_p20_ttt_val_bpb": null, + "pre_quant_p20_ttt_val_loss": null, + "pre_quant_sliding_eval_time_ms": null, + "pre_quant_sliding_val_bpb": null, + "pre_quant_sliding_val_loss": null, + "pre_quant_val_bpb": 1.3562212709803687, + "pre_quant_val_loss": 2.2899228062501846, + "quant_compression": "zstd", + "quant_file_bytes": 9524047, + "quant_format": "mixed_int6_clipsearch", + "quant_raw_bytes": 21389068, + "quant_stats": { + "baseline_tensor_bytes": 81696064, + "clip_quantiles": [ + 0.999, + 0.9995, + 0.9999, + 0.99999, + 1.0 + ], + "clipsearch_mse_sum": 0.00554943169845501, + "int6_categories": [ + "attn", + "mlp", + "p20" + ], + "int6_tensor_count": 64, + "int8_tensor_count": 1, + "num_float_tensors": 114, + "num_nonfloat_tensors": 0, + "num_tensors": 114, + "param_count": 21161296, + "passthrough_tensor_count": 49, + "quantized_payload_bytes": 21332288 + }, + "raw_model_bytes": 81752452, + "recurrence_activated_step": null, + "recurrence_activated_train_time_ms": null, + "recurrence_aux_applications": null, + "recurrence_aux_kind": null, + "recurrence_aux_mode": null, + "recurrence_aux_scale_init": null, + "recurrence_decoder_schedule": null, + "recurrence_encoder_schedule": null, + "recurrence_start_fraction": null, + "recurrence_start_step": null, + "run_id": "p20_AAAAAPAAAAA_sp1024_80shards_10min_gate_seed42_20260412T043154Z", + "scalar_lr": 0.04, + "stopped_early": true, + "sw_eval_batch": 32, + "tied_embed_lr": 0.05, + "total_submission_bytes_int8_zlib": 9747772, + "total_submission_bytes_quantized": 9747772, + "total_submission_bytes_raw": 81976177, + "train_batch_tokens": 524288, + "train_seq_len": 1024, + "train_shards": 80, + "train_time_ms": 600491.2087989505, + "train_tokens": 8000000000, + "val_batch_size": 524288, + "val_max_batches": null, + "vocab_size": 1024, + "warmup_steps_executed": 0, + "warmup_time_ms": 0.0, + "world_size": 1 +} diff --git a/records/track_non_record_16mb/2026-04-12_Fractal_RecurrentPrimitive_SP1024_1xH100/p20_runtime.py b/records/track_non_record_16mb/2026-04-12_Fractal_RecurrentPrimitive_SP1024_1xH100/p20_runtime.py new file mode 100644 index 0000000000..ddb0ba16a8 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-12_Fractal_RecurrentPrimitive_SP1024_1xH100/p20_runtime.py @@ -0,0 +1,974 @@ +from __future__ import annotations + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +try: + import triton + import triton.language as tl +except Exception: + triton = None + tl = None + + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class PackedLinearProjection(nn.Module): + def __init__(self, d_in: int, split_sizes: tuple[int, ...], *, bias: bool = True) -> None: + super().__init__() + if not split_sizes or any(size <= 0 for size in split_sizes): + raise ValueError(f"invalid packed projection split sizes: {split_sizes}") + self.split_sizes = split_sizes + self.projection = CastedLinear(d_in, sum(split_sizes), bias=bias) + + def forward(self, inputs: Tensor) -> tuple[Tensor, ...]: + packed = self.projection(inputs) + return tuple(packed.split(self.split_sizes, dim=-1)) + + +class BlockDiagonalLinear(nn.Module): + def __init__(self, width: int, blocks: int) -> None: + super().__init__() + if width % blocks != 0: + raise ValueError(f"block-diagonal width {width} must be divisible by blocks {blocks}") + self.width = width + self.blocks = blocks + self.block_width = width // blocks + # Keep this as a 2D parameter so the existing Muon split can treat it as a matrix. + self.weight = nn.Parameter(torch.empty(width, self.block_width)) + self.bias = nn.Parameter(torch.empty(width)) + self.reset_parameters() + + def reset_parameters(self) -> None: + nn.init.kaiming_uniform_(self.weight, a=5**0.5) + bound = 1.0 / self.block_width**0.5 + nn.init.uniform_(self.bias, -bound, bound) + + def forward(self, inputs: Tensor) -> Tensor: + leading = inputs.shape[:-1] + reshaped_inputs = inputs.reshape(-1, self.blocks, self.block_width) + reshaped_weight = self.weight.reshape(self.blocks, self.block_width, self.block_width) + projected = torch.einsum("bgi,goi->bgo", reshaped_inputs, reshaped_weight) + projected = projected.reshape(*leading, self.width) + return projected + self.bias.to(dtype=projected.dtype) + + +def triton_runtime_available() -> bool: + return triton is not None and tl is not None + + +def ensure_triton_runtime_available() -> None: + if not triton_runtime_available(): + raise RuntimeError( + "P20_RUNTIME_BACKEND=triton requires a CUDA environment with Triton installed" + ) + + +def validate_p20_runtime_backend(runtime_backend: str) -> None: + if runtime_backend not in {"torch", "triton"}: + raise ValueError(f"P20 runtime backend must be torch|triton, got {runtime_backend}") + if runtime_backend == "triton": + ensure_triton_runtime_available() + + +def _next_power_of_two(value: int) -> int: + power = 1 + while power < value: + power <<= 1 + return power + + +if triton_runtime_available(): + + @triton.jit + def _p20_forward_kernel( + update_gate_ptr, + retain_gate_ptr, + transformed_state_ptr, + candidate_ptr, + output_gate_ptr, + next_state_ptr, + emitted_output_ptr, + input_row_stride, + output_row_stride, + width, + BLOCK_SIZE: tl.constexpr, + ): + batch_index = tl.program_id(0) + block_index = tl.program_id(1) + offsets = block_index * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < width + input_row_offsets = batch_index * input_row_stride + offsets + output_row_offsets = batch_index * output_row_stride + offsets + + update_gate = tl.load(update_gate_ptr + input_row_offsets, mask=mask, other=0.0) + retain_gate = tl.load(retain_gate_ptr + input_row_offsets, mask=mask, other=0.0) + transformed_state = tl.load(transformed_state_ptr + input_row_offsets, mask=mask, other=0.0) + candidate = tl.load(candidate_ptr + input_row_offsets, mask=mask, other=0.0) + output_gate = tl.load(output_gate_ptr + input_row_offsets, mask=mask, other=0.0) + + next_state = update_gate * transformed_state + retain_gate * candidate + emitted_output = output_gate * next_state + + tl.store(next_state_ptr + output_row_offsets, next_state, mask=mask) + tl.store(emitted_output_ptr + output_row_offsets, emitted_output, mask=mask) + + + @triton.jit + def _p20_backward_kernel( + grad_next_state_ptr, + grad_emitted_output_ptr, + update_gate_ptr, + retain_gate_ptr, + transformed_state_ptr, + candidate_ptr, + output_gate_ptr, + next_state_ptr, + grad_update_gate_ptr, + grad_retain_gate_ptr, + grad_transformed_state_ptr, + grad_candidate_ptr, + grad_output_gate_ptr, + grad_row_stride, + input_row_stride, + next_state_row_stride, + output_row_stride, + width, + BLOCK_SIZE: tl.constexpr, + ): + batch_index = tl.program_id(0) + block_index = tl.program_id(1) + offsets = block_index * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < width + grad_row_offsets = batch_index * grad_row_stride + offsets + input_row_offsets = batch_index * input_row_stride + offsets + next_state_row_offsets = batch_index * next_state_row_stride + offsets + output_row_offsets = batch_index * output_row_stride + offsets + + grad_next_state = tl.load(grad_next_state_ptr + grad_row_offsets, mask=mask, other=0.0) + grad_emitted_output = tl.load(grad_emitted_output_ptr + grad_row_offsets, mask=mask, other=0.0) + update_gate = tl.load(update_gate_ptr + input_row_offsets, mask=mask, other=0.0) + retain_gate = tl.load(retain_gate_ptr + input_row_offsets, mask=mask, other=0.0) + transformed_state = tl.load(transformed_state_ptr + input_row_offsets, mask=mask, other=0.0) + candidate = tl.load(candidate_ptr + input_row_offsets, mask=mask, other=0.0) + output_gate = tl.load(output_gate_ptr + input_row_offsets, mask=mask, other=0.0) + next_state = tl.load(next_state_ptr + next_state_row_offsets, mask=mask, other=0.0) + + total_grad_next_state = grad_next_state + grad_emitted_output * output_gate + grad_update_gate = total_grad_next_state * transformed_state + grad_retain_gate = total_grad_next_state * candidate + grad_transformed_state = total_grad_next_state * update_gate + grad_candidate = total_grad_next_state * retain_gate + grad_output_gate = grad_emitted_output * next_state + + tl.store(grad_update_gate_ptr + output_row_offsets, grad_update_gate, mask=mask) + tl.store(grad_retain_gate_ptr + output_row_offsets, grad_retain_gate, mask=mask) + tl.store(grad_transformed_state_ptr + output_row_offsets, grad_transformed_state, mask=mask) + tl.store(grad_candidate_ptr + output_row_offsets, grad_candidate, mask=mask) + tl.store(grad_output_gate_ptr + output_row_offsets, grad_output_gate, mask=mask) + + +class _P20FusedUpdateReadout(torch.autograd.Function): + @staticmethod + def forward(ctx, update_gate: Tensor, retain_gate: Tensor, transformed_state: Tensor, candidate: Tensor, output_gate: Tensor): + if not triton_runtime_available(): + raise RuntimeError("P20 Triton fused update requires Triton to be installed") + if update_gate.device.type != "cuda": + raise RuntimeError("P20 Triton fused update requires CUDA tensors") + for tensor in (retain_gate, transformed_state, candidate, output_gate): + if tensor.device != update_gate.device or tensor.shape != update_gate.shape: + raise RuntimeError("P20 Triton fused update requires matching CUDA tensor shapes") + + batch_size, width = update_gate.shape + block_size = min(1024, _next_power_of_two(width)) + input_row_stride = update_gate.stride(0) + next_state = torch.empty_like(update_gate) + emitted_output = torch.empty_like(update_gate) + output_row_stride = next_state.stride(0) + + grid = (batch_size, triton.cdiv(width, block_size)) + _p20_forward_kernel[grid]( + update_gate, + retain_gate, + transformed_state, + candidate, + output_gate, + next_state, + emitted_output, + input_row_stride, + output_row_stride, + width, + BLOCK_SIZE=block_size, + ) + ctx.block_size = block_size + ctx.input_row_stride = input_row_stride + ctx.output_row_stride = output_row_stride + ctx.width = width + ctx.save_for_backward(update_gate, retain_gate, transformed_state, candidate, output_gate, next_state) + return next_state, emitted_output + + @staticmethod + def backward(ctx, grad_next_state: Tensor, grad_emitted_output: Tensor): + update_gate, retain_gate, transformed_state, candidate, output_gate, next_state = ctx.saved_tensors + grad_update_gate = torch.empty_like(update_gate) + grad_retain_gate = torch.empty_like(retain_gate) + grad_transformed_state = torch.empty_like(transformed_state) + grad_candidate = torch.empty_like(candidate) + grad_output_gate = torch.empty_like(output_gate) + grad_next_state = grad_next_state.contiguous() + grad_emitted_output = grad_emitted_output.contiguous() + + grid = (update_gate.shape[0], triton.cdiv(ctx.width, ctx.block_size)) + _p20_backward_kernel[grid]( + grad_next_state, + grad_emitted_output, + update_gate, + retain_gate, + transformed_state, + candidate, + output_gate, + next_state, + grad_update_gate, + grad_retain_gate, + grad_transformed_state, + grad_candidate, + grad_output_gate, + grad_next_state.stride(0), + ctx.input_row_stride, + ctx.output_row_stride, + grad_update_gate.stride(0), + ctx.width, + BLOCK_SIZE=ctx.block_size, + ) + return grad_update_gate, grad_retain_gate, grad_transformed_state, grad_candidate, grad_output_gate + + +if triton_runtime_available(): + + @triton.jit + def _p20_block_diagonal_sequence_forward_kernel( + update_gate_ptr, + angle_cos_ptr, + angle_sin_ptr, + candidate_ptr, + output_gate_ptr, + initial_state_ptr, + transform_weight_ptr, + transform_bias_ptr, + state_history_ptr, + emitted_output_ptr, + tensor_batch_stride, + tensor_seq_stride, + pair_batch_stride, + pair_seq_stride, + state_batch_stride, + history_batch_stride, + history_seq_stride, + output_batch_stride, + output_seq_stride, + weight_block_stride, + weight_row_stride, + weight_col_stride, + block_width, + pair_width, + SEQ_LEN: tl.constexpr, + BLOCK_PAIR_WIDTH: tl.constexpr, + ): + batch_index = tl.program_id(0) + block_index = tl.program_id(1) + + pair_offsets = tl.arange(0, BLOCK_PAIR_WIDTH) + pair_mask = pair_offsets < pair_width + even_offsets = pair_offsets * 2 + odd_offsets = even_offsets + 1 + block_base = block_index * block_width + global_even_offsets = block_base + even_offsets + global_odd_offsets = block_base + odd_offsets + + bias_even = tl.load(transform_bias_ptr + global_even_offsets, mask=pair_mask, other=0.0).to(tl.float32) + bias_odd = tl.load(transform_bias_ptr + global_odd_offsets, mask=pair_mask, other=0.0).to(tl.float32) + + weight_block_ptr = transform_weight_ptr + block_index * weight_block_stride + matrix_mask = pair_mask[:, None] & pair_mask[None, :] + row_even = even_offsets[:, None] + row_odd = odd_offsets[:, None] + col_even = even_offsets[None, :] + col_odd = odd_offsets[None, :] + weight_even_even = tl.load( + weight_block_ptr + row_even * weight_row_stride + col_even * weight_col_stride, + mask=matrix_mask, + other=0.0, + ).to(tl.float32) + weight_even_odd = tl.load( + weight_block_ptr + row_even * weight_row_stride + col_odd * weight_col_stride, + mask=matrix_mask, + other=0.0, + ).to(tl.float32) + weight_odd_even = tl.load( + weight_block_ptr + row_odd * weight_row_stride + col_even * weight_col_stride, + mask=matrix_mask, + other=0.0, + ).to(tl.float32) + weight_odd_odd = tl.load( + weight_block_ptr + row_odd * weight_row_stride + col_odd * weight_col_stride, + mask=matrix_mask, + other=0.0, + ).to(tl.float32) + + state_base = batch_index * state_batch_stride + state_even = tl.load(initial_state_ptr + state_base + global_even_offsets, mask=pair_mask, other=0.0).to( + tl.float32 + ) + state_odd = tl.load(initial_state_ptr + state_base + global_odd_offsets, mask=pair_mask, other=0.0).to( + tl.float32 + ) + + history_batch_base = batch_index * history_batch_stride + tl.store(state_history_ptr + history_batch_base + global_even_offsets, state_even, mask=pair_mask) + tl.store(state_history_ptr + history_batch_base + global_odd_offsets, state_odd, mask=pair_mask) + + tensor_batch_base = batch_index * tensor_batch_stride + pair_batch_base = batch_index * pair_batch_stride + output_batch_base = batch_index * output_batch_stride + + for position in range(SEQ_LEN): + projected_even = ( + tl.sum(weight_even_even * state_even[None, :], axis=1) + + tl.sum(weight_even_odd * state_odd[None, :], axis=1) + + bias_even + ) + projected_odd = ( + tl.sum(weight_odd_even * state_even[None, :], axis=1) + + tl.sum(weight_odd_odd * state_odd[None, :], axis=1) + + bias_odd + ) + + pair_step_base = pair_batch_base + position * pair_seq_stride + block_index * pair_width + cos = tl.load(angle_cos_ptr + pair_step_base + pair_offsets, mask=pair_mask, other=0.0).to(tl.float32) + sin = tl.load(angle_sin_ptr + pair_step_base + pair_offsets, mask=pair_mask, other=0.0).to(tl.float32) + rotated_even = projected_even * cos - projected_odd * sin + rotated_odd = projected_even * sin + projected_odd * cos + + tensor_step_base = tensor_batch_base + position * tensor_seq_stride + block_base + update_even = tl.load(update_gate_ptr + tensor_step_base + even_offsets, mask=pair_mask, other=0.0).to( + tl.float32 + ) + update_odd = tl.load(update_gate_ptr + tensor_step_base + odd_offsets, mask=pair_mask, other=0.0).to( + tl.float32 + ) + retain_even = 1.0 - update_even + retain_odd = 1.0 - update_odd + candidate_even = tl.load(candidate_ptr + tensor_step_base + even_offsets, mask=pair_mask, other=0.0).to( + tl.float32 + ) + candidate_odd = tl.load(candidate_ptr + tensor_step_base + odd_offsets, mask=pair_mask, other=0.0).to( + tl.float32 + ) + output_even = tl.load(output_gate_ptr + tensor_step_base + even_offsets, mask=pair_mask, other=0.0).to( + tl.float32 + ) + output_odd = tl.load(output_gate_ptr + tensor_step_base + odd_offsets, mask=pair_mask, other=0.0).to( + tl.float32 + ) + + next_even = update_even * rotated_even + retain_even * candidate_even + next_odd = update_odd * rotated_odd + retain_odd * candidate_odd + emitted_even = output_even * next_even + emitted_odd = output_odd * next_odd + + output_step_base = output_batch_base + position * output_seq_stride + block_base + tl.store(emitted_output_ptr + output_step_base + even_offsets, emitted_even, mask=pair_mask) + tl.store(emitted_output_ptr + output_step_base + odd_offsets, emitted_odd, mask=pair_mask) + + history_step_base = history_batch_base + (position + 1) * history_seq_stride + block_base + tl.store(state_history_ptr + history_step_base + even_offsets, next_even, mask=pair_mask) + tl.store(state_history_ptr + history_step_base + odd_offsets, next_odd, mask=pair_mask) + + state_even = next_even + state_odd = next_odd + + + @triton.jit + def _p20_block_diagonal_sequence_backward_kernel( + grad_emitted_output_ptr, + grad_final_state_ptr, + update_gate_ptr, + angle_cos_ptr, + angle_sin_ptr, + candidate_ptr, + output_gate_ptr, + state_history_ptr, + transform_weight_ptr, + transform_bias_ptr, + grad_update_gate_ptr, + grad_angle_cos_ptr, + grad_angle_sin_ptr, + grad_candidate_ptr, + grad_output_gate_ptr, + grad_initial_state_ptr, + grad_transform_weight_ptr, + grad_transform_bias_ptr, + tensor_batch_stride, + tensor_seq_stride, + pair_batch_stride, + pair_seq_stride, + state_batch_stride, + history_batch_stride, + history_seq_stride, + weight_block_stride, + weight_row_stride, + weight_col_stride, + grad_weight_batch_stride, + grad_weight_block_stride, + grad_weight_row_stride, + grad_weight_col_stride, + grad_bias_batch_stride, + block_width, + pair_width, + SEQ_LEN: tl.constexpr, + BLOCK_PAIR_WIDTH: tl.constexpr, + ): + batch_index = tl.program_id(0) + block_index = tl.program_id(1) + + pair_offsets = tl.arange(0, BLOCK_PAIR_WIDTH) + pair_mask = pair_offsets < pair_width + even_offsets = pair_offsets * 2 + odd_offsets = even_offsets + 1 + block_base = block_index * block_width + global_even_offsets = block_base + even_offsets + global_odd_offsets = block_base + odd_offsets + + weight_block_ptr = transform_weight_ptr + block_index * weight_block_stride + matrix_mask = pair_mask[:, None] & pair_mask[None, :] + row_even = even_offsets[:, None] + row_odd = odd_offsets[:, None] + col_even = even_offsets[None, :] + col_odd = odd_offsets[None, :] + weight_even_even = tl.load( + weight_block_ptr + row_even * weight_row_stride + col_even * weight_col_stride, + mask=matrix_mask, + other=0.0, + ).to(tl.float32) + weight_even_odd = tl.load( + weight_block_ptr + row_even * weight_row_stride + col_odd * weight_col_stride, + mask=matrix_mask, + other=0.0, + ).to(tl.float32) + weight_odd_even = tl.load( + weight_block_ptr + row_odd * weight_row_stride + col_even * weight_col_stride, + mask=matrix_mask, + other=0.0, + ).to(tl.float32) + weight_odd_odd = tl.load( + weight_block_ptr + row_odd * weight_row_stride + col_odd * weight_col_stride, + mask=matrix_mask, + other=0.0, + ).to(tl.float32) + bias_even = tl.load(transform_bias_ptr + global_even_offsets, mask=pair_mask, other=0.0).to(tl.float32) + bias_odd = tl.load(transform_bias_ptr + global_odd_offsets, mask=pair_mask, other=0.0).to(tl.float32) + + grad_state_even = tl.load( + grad_final_state_ptr + batch_index * state_batch_stride + global_even_offsets, + mask=pair_mask, + other=0.0, + ).to(tl.float32) + grad_state_odd = tl.load( + grad_final_state_ptr + batch_index * state_batch_stride + global_odd_offsets, + mask=pair_mask, + other=0.0, + ).to(tl.float32) + + grad_weight_even_even = tl.zeros((BLOCK_PAIR_WIDTH, BLOCK_PAIR_WIDTH), dtype=tl.float32) + grad_weight_even_odd = tl.zeros((BLOCK_PAIR_WIDTH, BLOCK_PAIR_WIDTH), dtype=tl.float32) + grad_weight_odd_even = tl.zeros((BLOCK_PAIR_WIDTH, BLOCK_PAIR_WIDTH), dtype=tl.float32) + grad_weight_odd_odd = tl.zeros((BLOCK_PAIR_WIDTH, BLOCK_PAIR_WIDTH), dtype=tl.float32) + grad_bias_even = tl.zeros((BLOCK_PAIR_WIDTH,), dtype=tl.float32) + grad_bias_odd = tl.zeros((BLOCK_PAIR_WIDTH,), dtype=tl.float32) + + tensor_batch_base = batch_index * tensor_batch_stride + pair_batch_base = batch_index * pair_batch_stride + history_batch_base = batch_index * history_batch_stride + + for position in range(SEQ_LEN - 1, -1, -1): + state_history_base = history_batch_base + position * history_seq_stride + block_base + next_state_history_base = history_batch_base + (position + 1) * history_seq_stride + block_base + state_even = tl.load(state_history_ptr + state_history_base + even_offsets, mask=pair_mask, other=0.0).to( + tl.float32 + ) + state_odd = tl.load(state_history_ptr + state_history_base + odd_offsets, mask=pair_mask, other=0.0).to( + tl.float32 + ) + next_even = tl.load( + state_history_ptr + next_state_history_base + even_offsets, + mask=pair_mask, + other=0.0, + ).to(tl.float32) + next_odd = tl.load( + state_history_ptr + next_state_history_base + odd_offsets, + mask=pair_mask, + other=0.0, + ).to(tl.float32) + + projected_even = ( + tl.sum(weight_even_even * state_even[None, :], axis=1) + + tl.sum(weight_even_odd * state_odd[None, :], axis=1) + + bias_even + ) + projected_odd = ( + tl.sum(weight_odd_even * state_even[None, :], axis=1) + + tl.sum(weight_odd_odd * state_odd[None, :], axis=1) + + bias_odd + ) + + pair_step_base = pair_batch_base + position * pair_seq_stride + block_index * pair_width + cos = tl.load(angle_cos_ptr + pair_step_base + pair_offsets, mask=pair_mask, other=0.0).to(tl.float32) + sin = tl.load(angle_sin_ptr + pair_step_base + pair_offsets, mask=pair_mask, other=0.0).to(tl.float32) + rotated_even = projected_even * cos - projected_odd * sin + rotated_odd = projected_even * sin + projected_odd * cos + + tensor_step_base = tensor_batch_base + position * tensor_seq_stride + block_base + update_even = tl.load(update_gate_ptr + tensor_step_base + even_offsets, mask=pair_mask, other=0.0).to( + tl.float32 + ) + update_odd = tl.load(update_gate_ptr + tensor_step_base + odd_offsets, mask=pair_mask, other=0.0).to( + tl.float32 + ) + retain_even = 1.0 - update_even + retain_odd = 1.0 - update_odd + candidate_even = tl.load(candidate_ptr + tensor_step_base + even_offsets, mask=pair_mask, other=0.0).to( + tl.float32 + ) + candidate_odd = tl.load(candidate_ptr + tensor_step_base + odd_offsets, mask=pair_mask, other=0.0).to( + tl.float32 + ) + output_even = tl.load(output_gate_ptr + tensor_step_base + even_offsets, mask=pair_mask, other=0.0).to( + tl.float32 + ) + output_odd = tl.load(output_gate_ptr + tensor_step_base + odd_offsets, mask=pair_mask, other=0.0).to( + tl.float32 + ) + + grad_output_base = tensor_batch_base + position * tensor_seq_stride + block_base + grad_emitted_even = tl.load( + grad_emitted_output_ptr + grad_output_base + even_offsets, + mask=pair_mask, + other=0.0, + ).to(tl.float32) + grad_emitted_odd = tl.load( + grad_emitted_output_ptr + grad_output_base + odd_offsets, + mask=pair_mask, + other=0.0, + ).to(tl.float32) + + total_grad_next_even = grad_state_even + grad_emitted_even * output_even + total_grad_next_odd = grad_state_odd + grad_emitted_odd * output_odd + grad_output_even_step = grad_emitted_even * next_even + grad_output_odd_step = grad_emitted_odd * next_odd + + grad_update_even = total_grad_next_even * rotated_even + grad_update_odd = total_grad_next_odd * rotated_odd + grad_retain_even = total_grad_next_even * candidate_even + grad_retain_odd = total_grad_next_odd * candidate_odd + grad_candidate_even = total_grad_next_even * retain_even + grad_candidate_odd = total_grad_next_odd * retain_odd + + grad_rotated_even = total_grad_next_even * update_even + grad_rotated_odd = total_grad_next_odd * update_odd + grad_projected_even = grad_rotated_even * cos + grad_rotated_odd * sin + grad_projected_odd = -grad_rotated_even * sin + grad_rotated_odd * cos + grad_cos = grad_rotated_even * projected_even + grad_rotated_odd * projected_odd + grad_sin = -grad_rotated_even * projected_odd + grad_rotated_odd * projected_even + + tl.store( + grad_update_gate_ptr + grad_output_base + even_offsets, + grad_update_even - grad_retain_even, + mask=pair_mask, + ) + tl.store( + grad_update_gate_ptr + grad_output_base + odd_offsets, + grad_update_odd - grad_retain_odd, + mask=pair_mask, + ) + tl.store(grad_candidate_ptr + grad_output_base + even_offsets, grad_candidate_even, mask=pair_mask) + tl.store(grad_candidate_ptr + grad_output_base + odd_offsets, grad_candidate_odd, mask=pair_mask) + tl.store(grad_output_gate_ptr + grad_output_base + even_offsets, grad_output_even_step, mask=pair_mask) + tl.store(grad_output_gate_ptr + grad_output_base + odd_offsets, grad_output_odd_step, mask=pair_mask) + tl.store(grad_angle_cos_ptr + pair_step_base + pair_offsets, grad_cos, mask=pair_mask) + tl.store(grad_angle_sin_ptr + pair_step_base + pair_offsets, grad_sin, mask=pair_mask) + + grad_weight_even_even += grad_projected_even[:, None] * state_even[None, :] + grad_weight_even_odd += grad_projected_even[:, None] * state_odd[None, :] + grad_weight_odd_even += grad_projected_odd[:, None] * state_even[None, :] + grad_weight_odd_odd += grad_projected_odd[:, None] * state_odd[None, :] + grad_bias_even += grad_projected_even + grad_bias_odd += grad_projected_odd + + grad_state_even = ( + tl.sum(weight_even_even * grad_projected_even[:, None], axis=0) + + tl.sum(weight_odd_even * grad_projected_odd[:, None], axis=0) + ) + grad_state_odd = ( + tl.sum(weight_even_odd * grad_projected_even[:, None], axis=0) + + tl.sum(weight_odd_odd * grad_projected_odd[:, None], axis=0) + ) + + grad_initial_base = batch_index * state_batch_stride + block_base + tl.store(grad_initial_state_ptr + grad_initial_base + even_offsets, grad_state_even, mask=pair_mask) + tl.store(grad_initial_state_ptr + grad_initial_base + odd_offsets, grad_state_odd, mask=pair_mask) + + grad_bias_base = batch_index * grad_bias_batch_stride + block_base + tl.store(grad_transform_bias_ptr + grad_bias_base + even_offsets, grad_bias_even, mask=pair_mask) + tl.store(grad_transform_bias_ptr + grad_bias_base + odd_offsets, grad_bias_odd, mask=pair_mask) + + grad_weight_block_ptr = ( + grad_transform_weight_ptr + + batch_index * grad_weight_batch_stride + + block_index * grad_weight_block_stride + ) + tl.store( + grad_weight_block_ptr + row_even * grad_weight_row_stride + col_even * grad_weight_col_stride, + grad_weight_even_even, + mask=matrix_mask, + ) + tl.store( + grad_weight_block_ptr + row_even * grad_weight_row_stride + col_odd * grad_weight_col_stride, + grad_weight_even_odd, + mask=matrix_mask, + ) + tl.store( + grad_weight_block_ptr + row_odd * grad_weight_row_stride + col_even * grad_weight_col_stride, + grad_weight_odd_even, + mask=matrix_mask, + ) + tl.store( + grad_weight_block_ptr + row_odd * grad_weight_row_stride + col_odd * grad_weight_col_stride, + grad_weight_odd_odd, + mask=matrix_mask, + ) + + +class _P20BlockDiagonalSequenceScan(torch.autograd.Function): + @staticmethod + def forward( + ctx, + update_gate: Tensor, + angle_cos: Tensor, + angle_sin: Tensor, + candidate: Tensor, + output_gate: Tensor, + initial_state: Tensor, + transform_weight: Tensor, + transform_bias: Tensor, + ) -> tuple[Tensor, Tensor]: + ensure_triton_runtime_available() + tensors = ( + update_gate, + angle_cos, + angle_sin, + candidate, + output_gate, + initial_state, + transform_weight, + transform_bias, + ) + if any(tensor.device.type != "cuda" for tensor in tensors): + raise RuntimeError("P20 Triton sequence scan requires CUDA tensors") + + update_gate = update_gate.contiguous() + angle_cos = angle_cos.contiguous() + angle_sin = angle_sin.contiguous() + candidate = candidate.contiguous() + output_gate = output_gate.contiguous() + initial_state = initial_state.contiguous() + transform_weight = transform_weight.contiguous() + transform_bias = transform_bias.contiguous() + + batch_size, seq_len, width = update_gate.shape + if candidate.shape != update_gate.shape or output_gate.shape != update_gate.shape: + raise RuntimeError("P20 Triton sequence scan requires matching [batch, seq, width] tensor shapes") + if initial_state.shape != (batch_size, width): + raise RuntimeError("P20 Triton sequence scan requires initial_state shape [batch, width]") + if angle_cos.shape != angle_sin.shape: + raise RuntimeError("P20 Triton sequence scan requires matching cosine and sine runtime plans") + if angle_cos.shape[:2] != update_gate.shape[:2]: + raise RuntimeError("P20 Triton sequence scan requires angle tensors aligned with [batch, seq]") + + blocks, block_width, block_width_cols = transform_weight.shape + if block_width != block_width_cols: + raise RuntimeError("P20 Triton sequence scan requires square block-diagonal transform weights") + if blocks * block_width != width: + raise RuntimeError("P20 Triton sequence scan requires block-diagonal weight width to match state width") + if block_width % 2 != 0: + raise RuntimeError("P20 Triton sequence scan requires an even per-block width for rotary pairs") + pair_width = block_width // 2 + expected_angle_width = width // 2 + if angle_cos.shape[2] != expected_angle_width: + raise RuntimeError("P20 Triton sequence scan requires angle width equal to state_width // 2") + block_pair_width = _next_power_of_two(pair_width) + + emitted_outputs = torch.empty_like(update_gate) + state_history = torch.empty( + batch_size, + seq_len + 1, + width, + device=update_gate.device, + dtype=update_gate.dtype, + ) + + grid = (batch_size, blocks) + _p20_block_diagonal_sequence_forward_kernel[grid]( + update_gate, + angle_cos, + angle_sin, + candidate, + output_gate, + initial_state, + transform_weight, + transform_bias, + state_history, + emitted_outputs, + update_gate.stride(0), + update_gate.stride(1), + angle_cos.stride(0), + angle_cos.stride(1), + initial_state.stride(0), + state_history.stride(0), + state_history.stride(1), + emitted_outputs.stride(0), + emitted_outputs.stride(1), + transform_weight.stride(0), + transform_weight.stride(1), + transform_weight.stride(2), + block_width, + pair_width, + SEQ_LEN=seq_len, + BLOCK_PAIR_WIDTH=block_pair_width, + ) + + ctx.seq_len = seq_len + ctx.block_pair_width = block_pair_width + ctx.pair_width = pair_width + ctx.save_for_backward( + update_gate, + angle_cos, + angle_sin, + candidate, + output_gate, + initial_state, + transform_weight, + transform_bias, + state_history, + ) + return emitted_outputs, state_history[:, -1, :].contiguous() + + @staticmethod + def backward(ctx, grad_emitted_outputs: Tensor | None, grad_final_state: Tensor | None): + ( + update_gate, + angle_cos, + angle_sin, + candidate, + output_gate, + initial_state, + transform_weight, + transform_bias, + state_history, + ) = ctx.saved_tensors + if grad_emitted_outputs is None: + grad_emitted_outputs = torch.zeros_like(update_gate) + else: + grad_emitted_outputs = grad_emitted_outputs.contiguous() + if grad_final_state is None: + grad_final_state = torch.zeros_like(initial_state) + else: + grad_final_state = grad_final_state.contiguous() + + batch_size, _seq_len, width = update_gate.shape + blocks, block_width, _block_width_cols = transform_weight.shape + + grad_update_gate = torch.empty_like(update_gate) + grad_angle_cos = torch.empty_like(angle_cos) + grad_angle_sin = torch.empty_like(angle_sin) + grad_candidate = torch.empty_like(candidate) + grad_output_gate = torch.empty_like(output_gate) + grad_initial_state = torch.empty_like(initial_state) + grad_transform_weight_contrib = torch.empty( + batch_size, + blocks, + block_width, + block_width, + device=transform_weight.device, + dtype=transform_weight.dtype, + ) + grad_transform_bias_contrib = torch.empty( + batch_size, + width, + device=transform_bias.device, + dtype=transform_bias.dtype, + ) + + grid = (batch_size, blocks) + _p20_block_diagonal_sequence_backward_kernel[grid]( + grad_emitted_outputs, + grad_final_state, + update_gate, + angle_cos, + angle_sin, + candidate, + output_gate, + state_history, + transform_weight, + transform_bias, + grad_update_gate, + grad_angle_cos, + grad_angle_sin, + grad_candidate, + grad_output_gate, + grad_initial_state, + grad_transform_weight_contrib, + grad_transform_bias_contrib, + update_gate.stride(0), + update_gate.stride(1), + angle_cos.stride(0), + angle_cos.stride(1), + initial_state.stride(0), + state_history.stride(0), + state_history.stride(1), + transform_weight.stride(0), + transform_weight.stride(1), + transform_weight.stride(2), + grad_transform_weight_contrib.stride(0), + grad_transform_weight_contrib.stride(1), + grad_transform_weight_contrib.stride(2), + grad_transform_weight_contrib.stride(3), + grad_transform_bias_contrib.stride(0), + block_width, + ctx.pair_width, + SEQ_LEN=ctx.seq_len, + BLOCK_PAIR_WIDTH=ctx.block_pair_width, + ) + grad_transform_weight = grad_transform_weight_contrib.sum(dim=0) + grad_transform_bias = grad_transform_bias_contrib.sum(dim=0) + return ( + grad_update_gate, + grad_angle_cos, + grad_angle_sin, + grad_candidate, + grad_output_gate, + grad_initial_state, + grad_transform_weight, + grad_transform_bias, + ) + + +def _rotate_state_pairs_with_trig(state: Tensor, *, cos: Tensor, sin: Tensor) -> Tensor: + batch_size, width = state.shape + pair_count = width // 2 + state_pairs = state.reshape(batch_size, pair_count, 2) + first = state_pairs[..., 0] + second = state_pairs[..., 1] + rotated_first = first * cos - second * sin + rotated_second = first * sin + second * cos + return torch.stack((rotated_first, rotated_second), dim=-1).reshape(batch_size, width) + + +class P20RuntimeSequenceMixer(nn.Module): + def __init__( + self, + d_model: int, + *, + state_transform_blocks: int = 2, + runtime_backend: str = "triton", + scan_chunk_size: int = 0, + ) -> None: + super().__init__() + if d_model % 2 != 0: + raise ValueError(f"P20 requires even d_model, got {d_model}") + validate_p20_runtime_backend(runtime_backend) + if scan_chunk_size < 0: + raise ValueError(f"P20 scan_chunk_size must be non-negative, got {scan_chunk_size}") + self.d_model = d_model + self.runtime_backend = runtime_backend + self.scan_chunk_size = scan_chunk_size + self.in_projection = PackedLinearProjection(d_model, (d_model, d_model // 2, d_model, d_model)) + self.state_transform_projection = BlockDiagonalLinear(d_model, blocks=state_transform_blocks) + + def forward(self, inputs: Tensor) -> Tensor: + update_gate_inputs, angle_inputs, candidate_inputs, output_gate_inputs = self.in_projection(inputs) + update_gates = torch.sigmoid(update_gate_inputs) + angle_cos = torch.cos(angle_inputs) + angle_sin = torch.sin(angle_inputs) + candidates = torch.tanh(candidate_inputs) + output_gates = torch.sigmoid(output_gate_inputs) + + batch_size, seq_len, _ = inputs.shape + state = torch.zeros(batch_size, self.d_model, device=inputs.device, dtype=inputs.dtype) + if self.runtime_backend == "triton": + if not inputs.is_cuda: + raise RuntimeError("P20 Triton sequence scan requires CUDA inputs") + transform_weight = self.state_transform_projection.weight.reshape( + self.state_transform_projection.blocks, + self.state_transform_projection.block_width, + self.state_transform_projection.block_width, + ) + scan_batch_size = batch_size + scan_seq_len = seq_len + if self.scan_chunk_size > 0 and seq_len > self.scan_chunk_size: + chunk_size = self.scan_chunk_size + if seq_len % chunk_size != 0: + raise RuntimeError( + f"P20_SCAN_CHUNK_SIZE={chunk_size} must evenly divide sequence length {seq_len}" + ) + chunk_count = seq_len // chunk_size + + def chunk_tensor(tensor: Tensor) -> Tensor: + return tensor.reshape(batch_size, chunk_count, chunk_size, *tensor.shape[2:]).reshape( + batch_size * chunk_count, + chunk_size, + *tensor.shape[2:], + ) + + update_gates = chunk_tensor(update_gates) + angle_cos = chunk_tensor(angle_cos) + angle_sin = chunk_tensor(angle_sin) + candidates = chunk_tensor(candidates) + output_gates = chunk_tensor(output_gates) + scan_batch_size = batch_size * chunk_count + scan_seq_len = chunk_size + state = torch.zeros(scan_batch_size, self.d_model, device=inputs.device, dtype=inputs.dtype) + outputs, _final_state = _P20BlockDiagonalSequenceScan.apply( + update_gates, + angle_cos, + angle_sin, + candidates, + output_gates, + state, + transform_weight, + self.state_transform_projection.bias, + ) + if self.scan_chunk_size > 0 and seq_len > self.scan_chunk_size: + outputs = outputs.reshape(batch_size, seq_len // scan_seq_len, scan_seq_len, self.d_model).reshape( + batch_size, + seq_len, + self.d_model, + ) + return outputs + + retain_gates = 1.0 - update_gates + outputs = torch.empty(batch_size, seq_len, self.d_model, device=inputs.device, dtype=inputs.dtype) + for position in range(seq_len): + projected_state = self.state_transform_projection(state) + transformed_state = _rotate_state_pairs_with_trig( + projected_state, + cos=angle_cos[:, position, :], + sin=angle_sin[:, position, :], + ) + state = update_gates[:, position, :] * transformed_state + retain_gates[:, position, :] * candidates[:, position, :] + outputs[:, position, :] = output_gates[:, position, :] * state + return outputs diff --git a/records/track_non_record_16mb/2026-04-12_Fractal_RecurrentPrimitive_SP1024_1xH100/p20_ttt.py b/records/track_non_record_16mb/2026-04-12_Fractal_RecurrentPrimitive_SP1024_1xH100/p20_ttt.py new file mode 100644 index 0000000000..873646cc8e --- /dev/null +++ b/records/track_non_record_16mb/2026-04-12_Fractal_RecurrentPrimitive_SP1024_1xH100/p20_ttt.py @@ -0,0 +1,782 @@ +from __future__ import annotations + +import math +import random +from dataclasses import dataclass +from typing import Iterable + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn + + +P20_TTT_MODES = frozenset({"off", "residual_scalar", "residual_vector"}) + + +@dataclass(frozen=True) +class P20TTTStats: + doc_count: int + chunk_count: int + update_count: int + token_count: float + byte_count: float + + +@dataclass(frozen=True) +class P20TTTPairedResult: + baseline_val_loss: float + baseline_val_bpb: float + ttt_val_loss: float + ttt_val_bpb: float + baseline_stats: P20TTTStats + ttt_stats: P20TTTStats + paired_stats: dict[str, float | int | None] + per_doc: list[dict[str, float | int]] + + +def normalize_p20_ttt_mode(raw_mode: str) -> str: + mode = raw_mode.strip().lower() + if mode not in P20_TTT_MODES: + raise ValueError(f"P20_TTT_MODE must be one of {sorted(P20_TTT_MODES)}, got {raw_mode!r}") + return mode + + +def _iter_p20_blocks(model: nn.Module) -> Iterable[nn.Module]: + blocks = getattr(model, "blocks", None) + if blocks is None: + return + for block in blocks: + if ( + hasattr(block, "primitive") + and hasattr(block, "readout_projection") + and hasattr(block, "residual_scale") + ): + yield block + + +def _doc_ranges_from_bos(val_tokens: Tensor, bos_token_id: int) -> list[tuple[int, int]]: + if val_tokens.device.type != "cpu": + val_tokens = val_tokens.cpu() + starts = (val_tokens == bos_token_id).nonzero(as_tuple=False).flatten().tolist() + if not starts: + raise ValueError(f"no BOS token id {bos_token_id} found in validation token stream") + if starts[0] != 0: + starts.insert(0, 0) + ends = starts[1:] + [int(val_tokens.numel())] + return [(start, end) for start, end in zip(starts, ends, strict=True) if end - start >= 2] + + +def _reset_adapter_params(params: list[Tensor], optimizer: torch.optim.Optimizer) -> None: + optimizer.zero_grad(set_to_none=True) + with torch.no_grad(): + for param in params: + param.zero_() + optimizer.state.clear() + + +def _set_p20_ttt_gate_params(model: nn.Module, mode: str, device: torch.device) -> tuple[list[nn.Module], list[Tensor]]: + blocks = list(_iter_p20_blocks(model)) + if not blocks: + raise ValueError("P20_TTT_MODE requires at least one P20 primitive block") + + params: list[Tensor] = [] + for block in blocks: + residual_scale = getattr(block, "residual_scale") + if mode == "residual_scalar": + shape = (1,) + elif mode == "residual_vector": + shape = tuple(residual_scale.shape) + else: + raise ValueError(f"unsupported P20 TTT mode {mode!r}") + gate = torch.zeros(shape, device=device, dtype=torch.float32, requires_grad=True) + setattr(block, "p20_ttt_residual_gate", gate) + params.append(gate) + return blocks, params + + +def _clear_p20_ttt_gate_params(blocks: list[nn.Module]) -> None: + for block in blocks: + setattr(block, "p20_ttt_residual_gate", None) + + +def _restore_requires_grad(params: list[nn.Parameter], flags: list[bool]) -> None: + for param, flag in zip(params, flags, strict=True): + param.requires_grad_(flag) + + +def _next_chunk_end(position: int, target_tokens: int, requested_chunk_size: int, scan_chunk_size: int) -> int: + remaining = target_tokens - position + if remaining <= requested_chunk_size: + if scan_chunk_size > 0 and remaining > scan_chunk_size: + rounded = (remaining // scan_chunk_size) * scan_chunk_size + if rounded > 0: + return position + rounded + return position + remaining + + chunk = requested_chunk_size + if scan_chunk_size > 0 and chunk > scan_chunk_size: + chunk = (chunk // scan_chunk_size) * scan_chunk_size + return position + max(chunk, 1) + + +def _resolve_context_size(args) -> int: + context_size = int(getattr(args, "p20_ttt_context_size", 0) or 0) + if context_size <= 0: + context_size = int(args.train_seq_len) + if context_size < args.p20_ttt_chunk_size: + raise ValueError( + "P20_TTT_CONTEXT_SIZE must be 0 or at least P20_TTT_CHUNK_SIZE; " + f"got context={context_size} chunk={args.p20_ttt_chunk_size}" + ) + return context_size + + +def _chunk_tensors( + val_tokens: Tensor, + doc_start_token: int, + position: int, + next_position: int, + context_size: int, + device: torch.device, +) -> tuple[Tensor, Tensor, Tensor, int]: + context_start = max(0, next_position - context_size) + if context_start > position: + raise RuntimeError( + f"context start {context_start} passed score position {position}; " + "increase P20_TTT_CONTEXT_SIZE or reduce P20_TTT_CHUNK_SIZE" + ) + x_cpu = val_tokens[doc_start_token + context_start : doc_start_token + next_position] + y_cpu = val_tokens[doc_start_token + position + 1 : doc_start_token + next_position + 1] + prev_cpu = val_tokens[doc_start_token + position : doc_start_token + next_position] + x = x_cpu.to(device=device, dtype=torch.int64, non_blocking=True).unsqueeze(0) + y = y_cpu.to(device=device, dtype=torch.int64, non_blocking=True).unsqueeze(0) + prev = prev_cpu.to(device=device, dtype=torch.int64, non_blocking=True) + score_offset = position - context_start + return x, y, prev, score_offset + + +def _token_byte_count( + previous_ids: Tensor, + target_ids: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> Tensor: + token_bytes = base_bytes_lut[target_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[target_ids] & ~is_boundary_token_lut[previous_ids]).to(dtype=torch.int16) + return token_bytes.to(torch.float64).sum() + + +def _bpb_from_loss_and_counts(loss_sum: float, token_count: float, byte_count: float) -> float: + if token_count <= 0 or byte_count <= 0: + return float("nan") + return (loss_sum / token_count) / math.log(2.0) * (token_count / byte_count) + + +def _summarize_paired_records( + records: list[dict[str, float | int]], + *, + bootstrap_samples: int, + bootstrap_seed: int, +) -> dict[str, float | int | None]: + if not records: + return { + "doc_count": 0, + "delta_bpb": None, + "mean_doc_delta_bpb": None, + "median_doc_delta_bpb": None, + "stdev_doc_delta_bpb": None, + "sem_doc_delta_bpb": None, + "bootstrap_samples": 0, + "bootstrap_ci95_low_bpb": None, + "bootstrap_ci95_high_bpb": None, + "bootstrap_p_two_sided": None, + "improved_doc_count": 0, + "worsened_doc_count": 0, + "unchanged_doc_count": 0, + } + + n_docs = len(records) + baseline_loss_sum = sum(float(record["baseline_loss_sum"]) for record in records) + ttt_loss_sum = sum(float(record["ttt_loss_sum"]) for record in records) + token_count = sum(float(record["token_count"]) for record in records) + byte_count = sum(float(record["byte_count"]) for record in records) + baseline_bpb = _bpb_from_loss_and_counts(baseline_loss_sum, token_count, byte_count) + ttt_bpb = _bpb_from_loss_and_counts(ttt_loss_sum, token_count, byte_count) + deltas = [float(record["delta_bpb"]) for record in records] + mean_delta = sum(deltas) / n_docs + median_delta = sorted(deltas)[n_docs // 2] if n_docs % 2 else 0.5 * ( + sorted(deltas)[n_docs // 2 - 1] + sorted(deltas)[n_docs // 2] + ) + if n_docs > 1: + variance = sum((delta - mean_delta) ** 2 for delta in deltas) / (n_docs - 1) + stdev_delta = variance**0.5 + sem_delta = stdev_delta / (n_docs**0.5) + else: + stdev_delta = 0.0 + sem_delta = 0.0 + + bootstrap_values: list[float] = [] + requested_samples = max(0, int(bootstrap_samples)) + if requested_samples > 0: + rng = random.Random(bootstrap_seed) + for _ in range(requested_samples): + sample_baseline_loss = 0.0 + sample_ttt_loss = 0.0 + sample_tokens = 0.0 + sample_bytes = 0.0 + for _doc in range(n_docs): + record = records[rng.randrange(n_docs)] + sample_baseline_loss += float(record["baseline_loss_sum"]) + sample_ttt_loss += float(record["ttt_loss_sum"]) + sample_tokens += float(record["token_count"]) + sample_bytes += float(record["byte_count"]) + sample_baseline_bpb = _bpb_from_loss_and_counts(sample_baseline_loss, sample_tokens, sample_bytes) + sample_ttt_bpb = _bpb_from_loss_and_counts(sample_ttt_loss, sample_tokens, sample_bytes) + bootstrap_values.append(sample_ttt_bpb - sample_baseline_bpb) + + if bootstrap_values: + ordered = sorted(bootstrap_values) + low_index = min(len(ordered) - 1, max(0, int(0.025 * len(ordered)))) + high_index = min(len(ordered) - 1, max(0, int(0.975 * len(ordered)))) + ci_low = ordered[low_index] + ci_high = ordered[high_index] + p_ge_zero = (sum(1 for value in bootstrap_values if value >= 0.0) + 1) / (len(bootstrap_values) + 1) + p_le_zero = (sum(1 for value in bootstrap_values if value <= 0.0) + 1) / (len(bootstrap_values) + 1) + p_two_sided = min(1.0, 2.0 * min(p_ge_zero, p_le_zero)) + else: + ci_low = None + ci_high = None + p_two_sided = None + + improved = sum(1 for delta in deltas if delta < 0.0) + worsened = sum(1 for delta in deltas if delta > 0.0) + unchanged = n_docs - improved - worsened + return { + "doc_count": n_docs, + "baseline_bpb": baseline_bpb, + "ttt_bpb": ttt_bpb, + "delta_bpb": ttt_bpb - baseline_bpb, + "mean_doc_delta_bpb": mean_delta, + "median_doc_delta_bpb": median_delta, + "stdev_doc_delta_bpb": stdev_delta, + "sem_doc_delta_bpb": sem_delta, + "bootstrap_samples": len(bootstrap_values), + "bootstrap_ci95_low_bpb": ci_low, + "bootstrap_ci95_high_bpb": ci_high, + "bootstrap_p_two_sided": p_two_sided, + "improved_doc_count": improved, + "worsened_doc_count": worsened, + "unchanged_doc_count": unchanged, + } + + +def _score_chunk( + model: nn.Module, + x: Tensor, + y: Tensor, + score_offset: int, + *, + device_type: str, + autocast_enabled: bool, + reduction: str, +) -> Tensor: + with torch.autocast(device_type=device_type, dtype=torch.bfloat16, enabled=autocast_enabled): + logits = model.forward_logits(x) + scored_logits = logits[:, score_offset : score_offset + y.size(1), :] + return F.cross_entropy( + scored_logits.reshape(-1, scored_logits.size(-1)).float(), + y.reshape(-1), + reduction=reduction, + ) + + +def eval_val_doc_chunks( + args, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + *, + bos_token_id: int, +) -> tuple[float, float, P20TTTStats]: + if args.p20_ttt_chunk_size <= 0: + raise ValueError(f"P20_TTT_CHUNK_SIZE must be positive, got {args.p20_ttt_chunk_size}") + + doc_ranges = _doc_ranges_from_bos(val_tokens, bos_token_id) + if args.p20_ttt_max_docs > 0: + doc_ranges = doc_ranges[: args.p20_ttt_max_docs] + doc_start = (len(doc_ranges) * rank) // world_size + doc_end = (len(doc_ranges) * (rank + 1)) // world_size + local_doc_ranges = doc_ranges[doc_start:doc_end] + + was_training = model.training + model.eval() + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + chunk_count = 0 + processed_docs = 0 + + context_size = _resolve_context_size(args) + autocast_enabled = device.type == "cuda" + scan_chunk_size = int(getattr(args, "p20_scan_chunk_size", 0) or 0) + with torch.no_grad(): + for doc_start_token, doc_end_token in local_doc_ranges: + target_tokens = doc_end_token - doc_start_token - 1 + if target_tokens <= 0: + continue + processed_docs += 1 + position = 0 + while position < target_tokens: + next_position = _next_chunk_end( + position, + target_tokens, + args.p20_ttt_chunk_size, + scan_chunk_size, + ) + x, y, previous_ids, score_offset = _chunk_tensors( + val_tokens, + doc_start_token, + position, + next_position, + context_size, + device, + ) + score_loss = _score_chunk( + model, + x, + y, + score_offset, + device_type=device.type, + autocast_enabled=autocast_enabled, + reduction="sum", + ) + val_loss_sum += score_loss.to(torch.float64) + val_token_count += float(y.numel()) + val_byte_count += _token_byte_count( + previous_ids, + y.reshape(-1), + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + chunk_count += 1 + position = next_position + + model.train(was_training) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + counts = torch.tensor([processed_docs, chunk_count], device=device, dtype=torch.float64) + dist.all_reduce(counts, op=dist.ReduceOp.SUM) + processed_docs = int(counts[0].item()) + chunk_count = int(counts[1].item()) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + stats = P20TTTStats( + doc_count=processed_docs, + chunk_count=chunk_count, + update_count=0, + token_count=float(val_token_count.item()), + byte_count=float(val_byte_count.item()), + ) + return float(val_loss.item()), float(bits_per_token * tokens_per_byte), stats + + +def eval_val_p20_gate_ttt_paired( + args, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + *, + bos_token_id: int, + bootstrap_samples: int, + bootstrap_seed: int, +) -> P20TTTPairedResult: + mode = normalize_p20_ttt_mode(args.p20_ttt_mode) + if mode == "off": + raise ValueError("eval_val_p20_gate_ttt_paired called with P20_TTT_MODE=off") + if args.p20_ttt_chunk_size <= 0: + raise ValueError(f"P20_TTT_CHUNK_SIZE must be positive, got {args.p20_ttt_chunk_size}") + + doc_ranges = _doc_ranges_from_bos(val_tokens, bos_token_id) + if args.p20_ttt_max_docs > 0: + doc_ranges = doc_ranges[: args.p20_ttt_max_docs] + doc_start = (len(doc_ranges) * rank) // world_size + doc_end = (len(doc_ranges) * (rank + 1)) // world_size + local_doc_ranges = doc_ranges[doc_start:doc_end] + + was_training = model.training + model.eval() + base_params = list(model.parameters()) + base_requires_grad = [param.requires_grad for param in base_params] + for param in base_params: + param.requires_grad_(False) + + blocks: list[nn.Module] = [] + adapter_params: list[Tensor] = [] + records: list[dict[str, float | int]] = [] + baseline_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + ttt_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count_sum = torch.zeros((), device=device, dtype=torch.float64) + byte_count_sum = torch.zeros((), device=device, dtype=torch.float64) + chunk_count = 0 + update_count = 0 + processed_docs = 0 + + try: + blocks, adapter_params = _set_p20_ttt_gate_params(model, mode, device) + optimizer = torch.optim.Adam( + adapter_params, + lr=args.p20_ttt_lr, + betas=(args.p20_ttt_beta1, args.p20_ttt_beta2), + weight_decay=args.p20_ttt_weight_decay, + ) + context_size = _resolve_context_size(args) + autocast_enabled = device.type == "cuda" + scan_chunk_size = int(getattr(args, "p20_scan_chunk_size", 0) or 0) + + for global_doc_index, (doc_start_token, doc_end_token) in enumerate( + doc_ranges[doc_start:doc_end], + start=doc_start, + ): + target_tokens = doc_end_token - doc_start_token - 1 + if target_tokens <= 0: + continue + processed_docs += 1 + doc_baseline_loss = 0.0 + doc_ttt_loss = 0.0 + doc_tokens = 0.0 + doc_bytes = 0.0 + doc_chunks = 0 + doc_updates = 0 + chunk_specs: list[tuple[int, int]] = [] + position = 0 + while position < target_tokens: + next_position = _next_chunk_end( + position, + target_tokens, + args.p20_ttt_chunk_size, + scan_chunk_size, + ) + chunk_specs.append((position, next_position)) + position = next_position + + _reset_adapter_params(adapter_params, optimizer) + with torch.no_grad(): + for position, next_position in chunk_specs: + x, y, previous_ids, score_offset = _chunk_tensors( + val_tokens, + doc_start_token, + position, + next_position, + context_size, + device, + ) + score_loss = _score_chunk( + model, + x, + y, + score_offset, + device_type=device.type, + autocast_enabled=autocast_enabled, + reduction="sum", + ) + doc_baseline_loss += float(score_loss.item()) + doc_tokens += float(y.numel()) + doc_bytes += float( + _token_byte_count( + previous_ids, + y.reshape(-1), + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ).item() + ) + doc_chunks += 1 + + _reset_adapter_params(adapter_params, optimizer) + for position, next_position in chunk_specs: + x, y, _previous_ids, score_offset = _chunk_tensors( + val_tokens, + doc_start_token, + position, + next_position, + context_size, + device, + ) + with torch.no_grad(): + score_loss = _score_chunk( + model, + x, + y, + score_offset, + device_type=device.type, + autocast_enabled=autocast_enabled, + reduction="sum", + ) + doc_ttt_loss += float(score_loss.item()) + + if next_position < target_tokens: + optimizer.zero_grad(set_to_none=True) + update_loss = _score_chunk( + model, + x, + y, + score_offset, + device_type=device.type, + autocast_enabled=autocast_enabled, + reduction="mean", + ) + update_loss.backward() + if args.p20_ttt_grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(adapter_params, args.p20_ttt_grad_clip_norm) + optimizer.step() + doc_updates += 1 + + baseline_loss_sum += doc_baseline_loss + ttt_loss_sum += doc_ttt_loss + token_count_sum += doc_tokens + byte_count_sum += doc_bytes + chunk_count += doc_chunks + update_count += doc_updates + baseline_bpb = _bpb_from_loss_and_counts(doc_baseline_loss, doc_tokens, doc_bytes) + ttt_bpb = _bpb_from_loss_and_counts(doc_ttt_loss, doc_tokens, doc_bytes) + records.append( + { + "doc_index": global_doc_index, + "doc_start_token": doc_start_token, + "doc_end_token": doc_end_token, + "token_count": doc_tokens, + "byte_count": doc_bytes, + "chunk_count": doc_chunks, + "update_count": doc_updates, + "baseline_loss_sum": doc_baseline_loss, + "ttt_loss_sum": doc_ttt_loss, + "baseline_bpb": baseline_bpb, + "ttt_bpb": ttt_bpb, + "delta_bpb": ttt_bpb - baseline_bpb, + } + ) + finally: + if blocks: + _clear_p20_ttt_gate_params(blocks) + _restore_requires_grad(base_params, base_requires_grad) + model.train(was_training) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(baseline_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(ttt_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count_sum, op=dist.ReduceOp.SUM) + counts = torch.tensor([processed_docs, chunk_count, update_count], device=device, dtype=torch.float64) + dist.all_reduce(counts, op=dist.ReduceOp.SUM) + processed_docs = int(counts[0].item()) + chunk_count = int(counts[1].item()) + update_count = int(counts[2].item()) + gathered_records: list[list[dict[str, float | int]]] = [list() for _ in range(world_size)] + dist.all_gather_object(gathered_records, records) + records = [record for shard in gathered_records for record in shard] + records.sort(key=lambda record: int(record["doc_index"])) + + baseline_val_loss = float((baseline_loss_sum / token_count_sum).item()) + ttt_val_loss = float((ttt_loss_sum / token_count_sum).item()) + baseline_val_bpb = _bpb_from_loss_and_counts( + float(baseline_loss_sum.item()), + float(token_count_sum.item()), + float(byte_count_sum.item()), + ) + ttt_val_bpb = _bpb_from_loss_and_counts( + float(ttt_loss_sum.item()), + float(token_count_sum.item()), + float(byte_count_sum.item()), + ) + baseline_stats = P20TTTStats( + doc_count=processed_docs, + chunk_count=chunk_count, + update_count=0, + token_count=float(token_count_sum.item()), + byte_count=float(byte_count_sum.item()), + ) + ttt_stats = P20TTTStats( + doc_count=processed_docs, + chunk_count=chunk_count, + update_count=update_count, + token_count=float(token_count_sum.item()), + byte_count=float(byte_count_sum.item()), + ) + paired_stats = _summarize_paired_records( + records, + bootstrap_samples=bootstrap_samples, + bootstrap_seed=bootstrap_seed, + ) + return P20TTTPairedResult( + baseline_val_loss=baseline_val_loss, + baseline_val_bpb=baseline_val_bpb, + ttt_val_loss=ttt_val_loss, + ttt_val_bpb=ttt_val_bpb, + baseline_stats=baseline_stats, + ttt_stats=ttt_stats, + paired_stats=paired_stats, + per_doc=records, + ) + + +def eval_val_p20_gate_ttt( + args, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + *, + bos_token_id: int, +) -> tuple[float, float, P20TTTStats]: + mode = normalize_p20_ttt_mode(args.p20_ttt_mode) + if mode == "off": + raise ValueError("eval_val_p20_gate_ttt called with P20_TTT_MODE=off") + if args.p20_ttt_chunk_size <= 0: + raise ValueError(f"P20_TTT_CHUNK_SIZE must be positive, got {args.p20_ttt_chunk_size}") + + doc_ranges = _doc_ranges_from_bos(val_tokens, bos_token_id) + if args.p20_ttt_max_docs > 0: + doc_ranges = doc_ranges[: args.p20_ttt_max_docs] + doc_start = (len(doc_ranges) * rank) // world_size + doc_end = (len(doc_ranges) * (rank + 1)) // world_size + local_doc_ranges = doc_ranges[doc_start:doc_end] + + was_training = model.training + model.eval() + base_params = list(model.parameters()) + base_requires_grad = [param.requires_grad for param in base_params] + for param in base_params: + param.requires_grad_(False) + + blocks: list[nn.Module] = [] + adapter_params: list[Tensor] = [] + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + chunk_count = 0 + update_count = 0 + processed_docs = 0 + + try: + blocks, adapter_params = _set_p20_ttt_gate_params(model, mode, device) + optimizer = torch.optim.Adam( + adapter_params, + lr=args.p20_ttt_lr, + betas=(args.p20_ttt_beta1, args.p20_ttt_beta2), + weight_decay=args.p20_ttt_weight_decay, + ) + + context_size = _resolve_context_size(args) + autocast_enabled = device.type == "cuda" + scan_chunk_size = int(getattr(args, "p20_scan_chunk_size", 0) or 0) + for doc_start_token, doc_end_token in local_doc_ranges: + _reset_adapter_params(adapter_params, optimizer) + target_tokens = doc_end_token - doc_start_token - 1 + if target_tokens <= 0: + continue + processed_docs += 1 + position = 0 + while position < target_tokens: + next_position = _next_chunk_end( + position, + target_tokens, + args.p20_ttt_chunk_size, + scan_chunk_size, + ) + x, y, previous_ids, score_offset = _chunk_tensors( + val_tokens, + doc_start_token, + position, + next_position, + context_size, + device, + ) + + with torch.no_grad(): + score_loss = _score_chunk( + model, + x, + y, + score_offset, + device_type=device.type, + autocast_enabled=autocast_enabled, + reduction="sum", + ) + val_loss_sum += score_loss.to(torch.float64) + val_token_count += float(y.numel()) + val_byte_count += _token_byte_count( + previous_ids, + y.reshape(-1), + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + chunk_count += 1 + + if next_position < target_tokens: + optimizer.zero_grad(set_to_none=True) + update_loss = _score_chunk( + model, + x, + y, + score_offset, + device_type=device.type, + autocast_enabled=autocast_enabled, + reduction="mean", + ) + update_loss.backward() + if args.p20_ttt_grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(adapter_params, args.p20_ttt_grad_clip_norm) + optimizer.step() + update_count += 1 + + position = next_position + finally: + if blocks: + _clear_p20_ttt_gate_params(blocks) + _restore_requires_grad(base_params, base_requires_grad) + model.train(was_training) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + counts = torch.tensor([processed_docs, chunk_count, update_count], device=device, dtype=torch.float64) + dist.all_reduce(counts, op=dist.ReduceOp.SUM) + processed_docs = int(counts[0].item()) + chunk_count = int(counts[1].item()) + update_count = int(counts[2].item()) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + stats = P20TTTStats( + doc_count=processed_docs, + chunk_count=chunk_count, + update_count=update_count, + token_count=float(val_token_count.item()), + byte_count=float(val_byte_count.item()), + ) + return float(val_loss.item()), float(bits_per_token * tokens_per_byte), stats diff --git a/records/track_non_record_16mb/2026-04-12_Fractal_RecurrentPrimitive_SP1024_1xH100/submission.json b/records/track_non_record_16mb/2026-04-12_Fractal_RecurrentPrimitive_SP1024_1xH100/submission.json new file mode 100644 index 0000000000..78cbe4e629 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-12_Fractal_RecurrentPrimitive_SP1024_1xH100/submission.json @@ -0,0 +1,18 @@ +{ + "author": "Joseph Abraham", + "github_id": "abbudjoe", + "name": "Fractal Recurrent Primitive Hybrid (1xH100, SP1024)", + "blurb": "Non-record research submission: a custom Fractal recurrent primitive inserted as a single middle layer in an 11L/512 SP1024 transformer. A 10-minute 1xH100 run reached 1.3562 pre-quant BPB; checkpoint requant with all-large-int8 reached 1.3576 BPB at 14.44MB. Includes pure-attention quant control showing the recurrent path is close but not yet a record lane.", + "date": "2026-04-12T10:10:00Z", + "track": "non_record_16mb", + "val_loss": 2.2922833208646662, + "val_bpb": 1.357619300696428, + "pre_quant_val_loss": 2.2899228062501846, + "pre_quant_val_bpb": 1.3562212709803687, + "step_stop": 988, + "wallclock_seconds": 600.491, + "bytes_total": 14440584, + "bytes_model_int8_zlib": 14208732, + "bytes_code": 231852, + "gpu": "1xH100 80GB" +} diff --git a/records/track_non_record_16mb/2026-04-12_Fractal_RecurrentPrimitive_SP1024_1xH100/supporting_files/requant_checkpoint_sweep.py b/records/track_non_record_16mb/2026-04-12_Fractal_RecurrentPrimitive_SP1024_1xH100/supporting_files/requant_checkpoint_sweep.py new file mode 100644 index 0000000000..206c835d15 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-12_Fractal_RecurrentPrimitive_SP1024_1xH100/supporting_files/requant_checkpoint_sweep.py @@ -0,0 +1,290 @@ +#!/usr/bin/env python3 +"""Run quantization-only sweeps against an existing Parameter Golf checkpoint.""" + +from __future__ import annotations + +import argparse +import io +import json +import os +import sys +import time +from pathlib import Path +from typing import Any + + +def _set_env_defaults_from_summary(summary: dict[str, Any]) -> None: + mappings = { + "VOCAB_SIZE": "vocab_size", + "NUM_LAYERS": "num_layers", + "NUM_KV_HEADS": "num_kv_heads", + "MODEL_DIM": "model_dim", + "NUM_HEADS": "num_heads", + "MLP_MULT": "mlp_mult", + "MLP_HIDDEN_DIM": "mlp_hidden_dim", + "MODEL_FAMILY": "model_family", + "P20_RUNTIME_BACKEND": "p20_runtime_backend", + "P20_LAYER_SCHEDULE": "p20_layer_schedule", + "P20_STATE_BLOCKS": "p20_state_blocks_requested", + "P20_BLOCK_PAIR_WIDTH_CAP": "p20_block_pair_width_cap", + "P20_SCAN_CHUNK_SIZE": "p20_scan_chunk_size", + "P20_ADAPTER_DIM": "p20_adapter_dim", + "P20_ADAPTER_SCALE_INIT": "p20_adapter_scale_init", + "P20_LOOP_REPEATS": "p20_loop_repeats", + "P20_PRIMITIVE_LOOP_REPEATS": "p20_primitive_loop_repeats", + "P20_PRIMITIVE_LOOP_DELTA_SCALE": "p20_primitive_loop_delta_scale", + "LOOP_DELTA_SCALE": "loop_delta_scale", + "TRAIN_SEQ_LEN": "train_seq_len", + "VAL_BATCH_SIZE": "val_batch_size", + "QUANT_COMPRESSION": "quant_compression", + } + for env_name, key in mappings.items(): + value = summary.get(key) + if value is None or value == "": + continue + os.environ.setdefault(env_name, str(value)) + + if "tie_embeddings" in summary: + os.environ.setdefault("TIE_EMBEDDINGS", "1" if summary["tie_embeddings"] else "0") + os.environ.setdefault("QUANT_FORMAT", "mixed_int6_clipsearch") + os.environ.setdefault("MIXED_INT6_CLIP_QUANTILES", "0.999,0.9995,0.9999,0.99999,1.0") + if "compiled_model" in summary: + os.environ.setdefault("COMPILE_MODEL", "1" if summary["compiled_model"] else "0") + else: + os.environ.setdefault("COMPILE_MODEL", "0") + os.environ.setdefault("COMPILE_MUON_BACKEND", "0") + + +def _parse_categories(raw: str) -> frozenset[str]: + return frozenset(part.strip() for part in raw.split(",") if part.strip()) + + +def _variant_contracts(default_patterns: tuple[str, ...]) -> list[dict[str, Any]]: + p20_small_fp32_patterns = tuple( + dict.fromkeys( + ( + *default_patterns, + "residual_scale", + "input_norm", + "output_norm", + "primitive.in_projection.projection.bias", + "primitive.state_transform_projection.bias", + "readout_projection.bias", + ) + ) + ) + return [ + { + "name": "p20_int8_default", + "mixed_int6_categories": "attn,mlp", + "fp32_patterns": default_patterns, + "note": "Reproduce the Step 3 leader: attention/MLP int6, P20 large matrices int8, small tensors passthrough.", + }, + { + "name": "p20_int8_p20_small_fp32", + "mixed_int6_categories": "attn,mlp", + "fp32_patterns": p20_small_fp32_patterns, + "note": "Keep P20 residual/norm/bias tensors in fp32 instead of fp16 passthrough.", + }, + { + "name": "attn_int8_mlp_int6_p20_int8", + "mixed_int6_categories": "mlp", + "fp32_patterns": default_patterns, + "note": "Spend bytes protecting attention and P20 at int8 while leaving MLP at int6.", + }, + { + "name": "attn_int6_mlp_int8_p20_int8", + "mixed_int6_categories": "attn", + "fp32_patterns": default_patterns, + "note": "Spend bytes protecting MLP and P20 at int8 while leaving attention at int6.", + }, + { + "name": "all_large_int8", + "mixed_int6_categories": "", + "fp32_patterns": default_patterns, + "note": "Use int8 for all large tensors; this is the upper-bound quantization-protection check.", + }, + ] + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--raw-model-path", required=True) + parser.add_argument("--source-summary-path", required=True) + parser.add_argument("--out-dir", required=True) + parser.add_argument("--run-id", required=True) + args_ns = parser.parse_args() + + source_summary_path = Path(args_ns.source_summary_path) + source_summary = json.loads(source_summary_path.read_text(encoding="utf-8")) + _set_env_defaults_from_summary(source_summary) + + repo_root = Path(__file__).resolve().parents[3] + sys.path.insert(0, str(repo_root)) + + import torch + import train_gpt + + out_dir = Path(args_ns.out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + summary_path = out_dir / "summary.json" + device = torch.device("cuda", 0) + torch.cuda.set_device(device) + + args = train_gpt.Hyperparameters() + sp = train_gpt.spm.SentencePieceProcessor() + sp.load(args.tokenizer_path) + val_tokens = train_gpt.load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = train_gpt.build_sentencepiece_luts( + sp, + args.vocab_size, + device, + ) + + base_model = train_gpt.build_model(args).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, train_gpt.CastedLinear): + module.float() + train_gpt.restore_low_dim_params_to_fp32(base_model) + use_compiled_model = train_gpt.resolve_compile_model(args) + eval_model = ( + torch.compile(base_model, dynamic=False, fullgraph=args.model_family == "baseline") + if use_compiled_model + else base_model + ) + + raw_model_path = Path(args_ns.raw_model_path) + state = torch.load(raw_model_path, map_location="cpu") + base_model.load_state_dict(state, strict=True) + train_gpt.restore_low_dim_params_to_fp32(base_model) + state = {name: tensor.detach().cpu().contiguous() for name, tensor in base_model.state_dict().items()} + + torch.cuda.synchronize() + pre_t0 = time.perf_counter() + pre_loss, pre_bpb = train_gpt.eval_val( + args, + eval_model, + 0, + 1, + device, + 8, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + pre_eval_ms = 1000.0 * (time.perf_counter() - pre_t0) + + code_bytes, code_files = train_gpt.collect_submission_code_bytes(Path(train_gpt.__file__).resolve()) + default_fp32_patterns = train_gpt.INT8_KEEP_FLOAT_FP32_NAME_PATTERNS + variants: list[dict[str, Any]] = [] + + for contract in _variant_contracts(default_fp32_patterns): + train_gpt.INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple(contract["fp32_patterns"]) + categories = _parse_categories(contract["mixed_int6_categories"]) + torch.cuda.synchronize() + quant_t0 = time.perf_counter() + quant_obj, quant_stats = train_gpt.quantize_state_dict_mixed_int6_clipsearch( + state, + int6_categories=categories, + clip_quantiles=args.mixed_int6_clip_quantiles, + ) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = train_gpt.compress_quant_payload(quant_raw, args.quant_compression) + quant_path = out_dir / f"{args_ns.run_id}_{contract['name']}.int6clip.{args.quant_compression}.ptz" + quant_path.write_bytes(quant_blob) + quant_time_ms = 1000.0 * (time.perf_counter() - quant_t0) + + roundtrip_obj = torch.load( + io.BytesIO(train_gpt.decompress_quant_payload(quant_blob, args.quant_compression)), + map_location="cpu", + ) + base_model.load_state_dict(train_gpt.dequantize_state_dict_export(roundtrip_obj), strict=True) + train_gpt.restore_low_dim_params_to_fp32(base_model) + + torch.cuda.synchronize() + eval_t0 = time.perf_counter() + post_loss, post_bpb = train_gpt.eval_val( + args, + eval_model, + 0, + 1, + device, + 8, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + post_eval_ms = 1000.0 * (time.perf_counter() - eval_t0) + + variant = { + "name": contract["name"], + "note": contract["note"], + "mixed_int6_categories": sorted(categories), + "fp32_patterns": list(contract["fp32_patterns"]), + "post_quant_val_loss": post_loss, + "post_quant_val_bpb": post_bpb, + "post_minus_pre_bpb": post_bpb - pre_bpb, + "post_quant_eval_time_ms": post_eval_ms, + "quant_time_ms": quant_time_ms, + "quant_file_bytes": quant_path.stat().st_size, + "quant_raw_bytes": len(quant_raw), + "total_submission_bytes_quantized": quant_path.stat().st_size + code_bytes, + "quant_path": str(quant_path), + "quant_stats": quant_stats, + } + variants.append(variant) + print("variant " + json.dumps(variant, sort_keys=True), flush=True) + + best = min(variants, key=lambda item: item["post_quant_val_bpb"]) + summary = { + "event": "checkpoint_requant_sweep_completed", + "experiment_kind": "step4_quantization_protection_sweep", + "run_id": args_ns.run_id, + "source_run_id": source_summary.get("run_id"), + "raw_model_path": str(raw_model_path), + "source_summary_path": str(source_summary_path), + "code_bytes": code_bytes, + "code_files": code_files, + "compiled_model": use_compiled_model, + "model_family": args.model_family, + "vocab_size": args.vocab_size, + "num_layers": args.num_layers, + "model_dim": args.model_dim, + "num_heads": args.num_heads, + "num_kv_heads": args.num_kv_heads, + "mlp_mult": args.mlp_mult, + "mlp_hidden_dim": args.mlp_hidden_dim if args.mlp_hidden_dim > 0 else args.mlp_mult * args.model_dim, + "p20_layer_schedule": args.p20_layer_schedule, + "p20_runtime_backend": args.p20_runtime_backend, + "p20_state_blocks": args.p20_state_blocks, + "p20_block_pair_width_cap": args.p20_block_pair_width_cap, + "pre_quant_val_loss": pre_loss, + "pre_quant_val_bpb": pre_bpb, + "pre_quant_eval_time_ms": pre_eval_ms, + "quant_compression": args.quant_compression, + "mixed_int6_clip_quantiles": list(args.mixed_int6_clip_quantiles), + "best_variant": best["name"], + "best_post_quant_val_bpb": best["post_quant_val_bpb"], + "best_post_quant_val_loss": best["post_quant_val_loss"], + "best_total_submission_bytes_quantized": best["total_submission_bytes_quantized"], + "variants": variants, + } + summary_path.write_text(json.dumps(summary, indent=2, sort_keys=True) + "\n", encoding="utf-8") + print("summary_path " + str(summary_path), flush=True) + print(json.dumps(summary, sort_keys=True), flush=True) + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + raise + except Exception as exc: + print(f"error: {exc}", file=sys.stderr) + raise diff --git a/records/track_non_record_16mb/2026-04-12_Fractal_RecurrentPrimitive_SP1024_1xH100/train_gpt.py b/records/track_non_record_16mb/2026-04-12_Fractal_RecurrentPrimitive_SP1024_1xH100/train_gpt.py new file mode 100644 index 0000000000..038e10796d --- /dev/null +++ b/records/track_non_record_16mb/2026-04-12_Fractal_RecurrentPrimitive_SP1024_1xH100/train_gpt.py @@ -0,0 +1,3521 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import json +import lzma +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +try: + import zstandard as zstd +except ImportError: # Optional; zlib/lzma remain dependency-free fallbacks. + zstd = None +from p20_runtime import P20RuntimeSequenceMixer, validate_p20_runtime_backend +from p20_ttt import eval_val_doc_chunks, eval_val_p20_gate_ttt, eval_val_p20_gate_ttt_paired, normalize_p20_ttt_mode +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + val_max_batches = int(os.environ.get("VAL_MAX_BATCHES", 0)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 0)) + sw_eval_batch = int(os.environ.get("SW_EVAL_BATCH", 32)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + mlp_hidden_dim = int(os.environ.get("MLP_HIDDEN_DIM", "0")) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", "0")) + bigram_dim = int(os.environ.get("BIGRAM_DIM", "64")) + bigram_scale_init = float(os.environ.get("BIGRAM_SCALE_INIT", "1.0")) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + model_family = os.environ.get("MODEL_FAMILY", "baseline") + p20_runtime_backend = os.environ.get("P20_RUNTIME_BACKEND", "torch") + p20_layer_schedule = os.environ.get("P20_LAYER_SCHEDULE", "") + p20_state_blocks = os.environ.get("P20_STATE_BLOCKS", "2") + p20_block_pair_width_cap = int(os.environ.get("P20_BLOCK_PAIR_WIDTH_CAP", "32")) + p20_scan_chunk_size = int(os.environ.get("P20_SCAN_CHUNK_SIZE", "0")) + p20_adapter_dim = int(os.environ.get("P20_ADAPTER_DIM", "0")) + p20_adapter_scale_init = float(os.environ.get("P20_ADAPTER_SCALE_INIT", "0.5")) + p20_loop_repeats = int(os.environ.get("P20_LOOP_REPEATS", "1")) + p20_primitive_loop_repeats = int(os.environ.get("P20_PRIMITIVE_LOOP_REPEATS", "1")) + p20_primitive_loop_repeats_by_p = os.environ.get("P20_PRIMITIVE_LOOP_REPEATS_BY_P", "") + p20_primitive_loop_delta_scale = float(os.environ.get("P20_PRIMITIVE_LOOP_DELTA_SCALE", "1.0")) + p20_ttt_mode = os.environ.get("P20_TTT_MODE", "off") + p20_ttt_chunk_size = int(os.environ.get("P20_TTT_CHUNK_SIZE", "64")) + p20_ttt_context_size = int(os.environ.get("P20_TTT_CONTEXT_SIZE", "0")) + p20_ttt_lr = float(os.environ.get("P20_TTT_LR", "0.01")) + p20_ttt_beta1 = float(os.environ.get("P20_TTT_BETA1", "0.0")) + p20_ttt_beta2 = float(os.environ.get("P20_TTT_BETA2", "0.999")) + p20_ttt_weight_decay = float(os.environ.get("P20_TTT_WEIGHT_DECAY", "0.0")) + p20_ttt_grad_clip_norm = float(os.environ.get("P20_TTT_GRAD_CLIP_NORM", "0.0")) + p20_ttt_max_docs = int(os.environ.get("P20_TTT_MAX_DOCS", "0")) + p20_ttt_bootstrap_samples = int(os.environ.get("P20_TTT_BOOTSTRAP_SAMPLES", "2000")) + p20_ttt_bootstrap_seed = int(os.environ.get("P20_TTT_BOOTSTRAP_SEED", "12345")) + p20_ttt_per_doc_path = os.environ.get("P20_TTT_PER_DOC_PATH", "") + baseline_loop_repeats = int(os.environ.get("BASELINE_LOOP_REPEATS", "1")) + baseline_loop_layer_indices = os.environ.get("BASELINE_LOOP_LAYER_INDICES", "") + loop_delta_scale = float(os.environ.get("LOOP_DELTA_SCALE", "1.0")) + compile_model = os.environ.get("COMPILE_MODEL", "auto") + compile_prewarm_steps = int(os.environ.get("COMPILE_PREWARM_STEPS", "0")) + compile_prewarm_counts_toward_training = bool(int(os.environ.get("COMPILE_PREWARM_COUNTS_TOWARD_TRAINING", "1"))) + mamba_layer_schedule = os.environ.get("MAMBA_LAYER_SCHEDULE", "") + mamba_d_state = int(os.environ.get("MAMBA_D_STATE", "64")) + mamba_expand = int(os.environ.get("MAMBA_EXPAND", "1")) + mamba_ngroups = int(os.environ.get("MAMBA_NGROUPS", "1")) + mamba_chunk_size = int(os.environ.get("MAMBA_CHUNK_SIZE", "16")) + mamba_is_mimo = bool(int(os.environ.get("MAMBA_IS_MIMO", "0"))) + mamba_mimo_rank = int(os.environ.get("MAMBA_MIMO_RANK", "1")) + mamba_include_mlp = bool(int(os.environ.get("MAMBA_INCLUDE_MLP", "1"))) + mamba_residual_scale_init = float(os.environ.get("MAMBA_RESIDUAL_SCALE_INIT", "1.0")) + recurrence_encoder_schedule = os.environ.get("RECURRENCE_ENCODER_SCHEDULE", "") + recurrence_decoder_schedule = os.environ.get("RECURRENCE_DECODER_SCHEDULE", "") + recurrence_start_fraction = float(os.environ.get("RECURRENCE_START_FRACTION", "0.0")) + recurrence_start_step = int(os.environ.get("RECURRENCE_START_STEP", "-1")) + recurrence_aux_kind = os.environ.get("RECURRENCE_AUX_KIND", "none") + recurrence_aux_mode = os.environ.get("RECURRENCE_AUX_MODE", "none") + recurrence_aux_applications = os.environ.get("RECURRENCE_AUX_APPLICATIONS", "") + recurrence_aux_scale_init = float(os.environ.get("RECURRENCE_AUX_SCALE_INIT", "0.5")) + compile_muon_backend = os.environ.get("COMPILE_MUON_BACKEND", "auto") + + # Optimizer hyperparameters. + optimizer_family = os.environ.get("OPTIMIZER_FAMILY", "muon").strip().lower() + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + p20_matrix_lr = float(os.environ["P20_MATRIX_LR"]) if os.environ.get("P20_MATRIX_LR", "") else None + p20_scalar_lr = float(os.environ["P20_SCALAR_LR"]) if os.environ.get("P20_SCALAR_LR", "") else None + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "0"))) + ema_decay = float(os.environ.get("EMA_DECAY", "0.997")) + ema_start_step = int(os.environ.get("EMA_START_STEP", "0")) + quant_format = os.environ.get("QUANT_FORMAT", "int8_zlib").strip().lower() + quant_compression = os.environ.get("QUANT_COMPRESSION", "zlib").strip().lower() + min_train_shards = int(os.environ.get("MIN_TRAIN_SHARDS", "0")) + min_train_tokens = int(os.environ.get("MIN_TRAIN_TOKENS", "0")) + mixed_int6_categories = frozenset( + category.strip() + for category in os.environ.get("MIXED_INT6_CATEGORIES", "attn,mlp").split(",") + if category.strip() + ) + mixed_int6_clip_quantiles = tuple( + float(value) + for value in os.environ.get("MIXED_INT6_CLIP_QUANTILES", "0.999,0.9995,0.9999,0.99999,1.0").split(",") + if value.strip() + ) + out_dir = os.environ.get("OUT_DIR", "logs") + run_summary_path = os.environ.get("RUN_SUMMARY_PATH", "") + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + # Keep eval cheap without creating inference tensors that can poison a + # compiled callable when the same model is used for later training steps. + with torch.no_grad(): + processed_batches = 0 + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + if args.val_max_batches > 0 and processed_batches >= args.val_max_batches: + break + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + processed_batches += 1 + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +def eval_val_sliding_window( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + if args.eval_stride <= 0: + raise ValueError(f"EVAL_STRIDE must be positive for sliding-window eval, got {args.eval_stride}") + seq_len = args.train_seq_len + stride = args.eval_stride + total_tokens = val_tokens.numel() - 1 + if total_tokens < seq_len: + raise ValueError(f"Validation split is too short for sliding eval seq_len={seq_len}") + + num_windows = max((total_tokens - seq_len) // stride + 1, 1) + win_start = (num_windows * rank) // world_size + win_end = (num_windows * (rank + 1)) // world_size + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + base_model.eval() + # Use no_grad rather than inference_mode so compiled models can be reused + # safely if evaluation is followed by additional training. + with torch.no_grad(): + window_list = list(range(win_start, win_end)) + num_batches = (len(window_list) + args.sw_eval_batch - 1) // args.sw_eval_batch + if args.val_max_batches > 0: + num_batches = min(num_batches, args.val_max_batches) + for batch_idx in range(num_batches): + batch_wins = window_list[batch_idx * args.sw_eval_batch : (batch_idx + 1) * args.sw_eval_batch] + if not batch_wins: + continue + inputs = torch.stack( + [val_tokens[w * stride : w * stride + seq_len] for w in batch_wins] + ).to(device=device, dtype=torch.int64) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits = base_model.forward_logits(inputs) + + scored_logits = logits[:, -stride:, :].reshape(-1, logits.size(-1)) + targets = torch.stack( + [val_tokens[w * stride + seq_len - stride + 1 : w * stride + seq_len + 1] for w in batch_wins] + ).to(device=device, dtype=torch.int64).reshape(-1) + + loss = F.cross_entropy(scored_logits.float(), targets, reduction="sum") + val_loss_sum += loss.to(torch.float64) + val_token_count += float(targets.numel()) + + prev_ids = torch.stack( + [val_tokens[w * stride + seq_len - stride : w * stride + seq_len] for w in batch_wins] + ).to(device=device, dtype=torch.int64).reshape(-1) + token_bytes = base_bytes_lut[targets].to(dtype=torch.int16) + token_bytes += ( + has_leading_space_lut[targets] & ~is_boundary_token_lut[prev_ids] + ).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + base_model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +def write_run_summary(path: Path, summary: dict[str, object]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(summary, indent=2, sort_keys=True) + "\n", encoding="utf-8") + + +def write_jsonl(path: Path, records: list[dict[str, object]]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="utf-8") as f: + for record in records: + f.write(json.dumps(record, sort_keys=True) + "\n") + + +def collect_submission_code_bytes(script_path: Path) -> tuple[int, list[dict[str, object]]]: + code_files = [script_path] + for sibling_name in ("p20_runtime.py", "p20_ttt.py"): + sibling = script_path.with_name(sibling_name) + if sibling.is_file(): + code_files.append(sibling) + entries = [] + total = 0 + for path in code_files: + byte_count = len(path.read_bytes()) + entries.append({"path": path.name, "bytes": byte_count}) + total += byte_count + return total, entries + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,bigram.scale", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantization_category(name: str) -> str: + if name.startswith("tok_emb") or name.startswith("lm_head") or ".tok_emb" in name or ".lm_head" in name: + return "embed" + if ".attn." in name: + return "attn" + if ".mlp." in name: + return "mlp" + if ".primitive." in name or ".readout_projection." in name or ".p20_adapter_" in name: + return "p20" + if ".bigram." in name or name.startswith("bigram."): + return "bigram" + return "other" + +def quantize_int6_clipsearch_tensor( + t: Tensor, + *, + clip_quantiles: tuple[float, ...], + clip_range: int = 31, +) -> tuple[Tensor, Tensor, float]: + t32 = t.float() + if t32.ndim == 2: + best_q: Tensor | None = None + best_scale: Tensor | None = None + best_error = float("inf") + for quantile in clip_quantiles or (1.0,): + row_clip = torch.quantile(t32.abs(), quantile, dim=1) if quantile < 1.0 else t32.abs().amax(dim=1) + scale = (row_clip / float(clip_range)).clamp_min(1.0 / float(clip_range)) + q = torch.clamp(torch.round(t32 / scale[:, None]), -clip_range, clip_range).to(torch.int8).contiguous() + recon = q.float() * scale[:, None] + error = float((t32 - recon).pow(2).mean().item()) + if error < best_error: + best_q = q + best_scale = scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + best_error = error + if best_q is None or best_scale is None: + raise RuntimeError("int6 clip-search failed to produce a quantized tensor") + return best_q, best_scale, best_error + + amax = float(t32.abs().max().item()) if t32.numel() else 0.0 + scale = torch.tensor(amax / float(clip_range) if amax > 0 else 1.0, dtype=INT8_PER_ROW_SCALE_DTYPE) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8).contiguous() + error = float((t32 - q.float() * scale.float()).pow(2).mean().item()) if t32.numel() else 0.0 + return q, scale, error + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def quantize_state_dict_mixed_int6_clipsearch( + state_dict: dict[str, Tensor], + *, + int6_categories: frozenset[str], + clip_quantiles: tuple[float, ...], +): + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ( + "param_count", + "num_tensors", + "num_float_tensors", + "num_nonfloat_tensors", + "baseline_tensor_bytes", + "quantized_payload_bytes", + "int6_tensor_count", + "int8_tensor_count", + "passthrough_tensor_count", + ), + 0, + ) + stats["int6_categories"] = sorted(int6_categories) + stats["clip_quantiles"] = list(clip_quantiles) + stats["clipsearch_mse_sum"] = 0.0 + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + qmeta[name] = {"type": "passthrough"} + stats["quantized_payload_bytes"] += tensor_nbytes(t) + stats["passthrough_tensor_count"] += 1 + continue + + stats["num_float_tensors"] += 1 + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL or any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + qmeta[name] = {"type": "passthrough_float"} + stats["quantized_payload_bytes"] += tensor_nbytes(kept) + stats["passthrough_tensor_count"] += 1 + continue + + category = quantization_category(name) + if category in int6_categories: + q, scale, mse = quantize_int6_clipsearch_tensor(t, clip_quantiles=clip_quantiles) + qtype = "int6_clipsearch" + stats["int6_tensor_count"] += 1 + stats["clipsearch_mse_sum"] += mse + else: + q, scale = quantize_float_tensor(t) + qtype = "int8" + stats["int8_tensor_count"] += 1 + quantized[name] = q + scales[name] = scale + dtypes[name] = str(t.dtype).removeprefix("torch.") + qmeta[name] = {"type": qtype, "category": category} + stats["quantized_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(scale) + + obj: dict[str, object] = { + "__quant_format__": "mixed_int6_clipsearch_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + "qmeta": qmeta, + "stats": stats, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + +def dequantize_state_dict_mixed_int6_clipsearch(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + scale = obj["scales"][name].to(dtype=torch.float32) + if scale.ndim > 0 and q.ndim >= 2: + value = q.float() * scale.view(q.shape[0], *([1] * (q.ndim - 1))) + else: + value = q.float() * float(scale.item()) + out[name] = value.to(dtype=dtype).contiguous() + return out + +def dequantize_state_dict_export(obj: dict[str, object]) -> dict[str, Tensor]: + if obj.get("__quant_format__") == "mixed_int6_clipsearch_v1": + return dequantize_state_dict_mixed_int6_clipsearch(obj) + return dequantize_state_dict_int8(obj) + +def compress_quant_payload(payload: bytes, compression: str) -> bytes: + if compression == "zlib": + return zlib.compress(payload, level=9) + if compression == "lzma": + return lzma.compress(payload, preset=9 | lzma.PRESET_EXTREME) + if compression == "zstd": + if zstd is None: + raise RuntimeError("QUANT_COMPRESSION=zstd requires the optional zstandard package") + return zstd.ZstdCompressor(level=22).compress(payload) + raise ValueError(f"QUANT_COMPRESSION must be zlib|lzma|zstd, got {compression}") + +def decompress_quant_payload(payload: bytes, compression: str) -> bytes: + if compression == "zlib": + return zlib.decompress(payload) + if compression == "lzma": + return lzma.decompress(payload) + if compression == "zstd": + if zstd is None: + raise RuntimeError("QUANT_COMPRESSION=zstd requires the optional zstandard package") + return zstd.ZstdDecompressor().decompress(payload) + raise ValueError(f"QUANT_COMPRESSION must be zlib|lzma|zstd, got {compression}") + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def data_shard_token_count(file: Path) -> int: + header = np.fromfile(file, dtype=" int: + return sum(data_shard_token_count(file) for file in files) + + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class CastedLayerNorm(nn.LayerNorm): + def forward(self, x: Tensor) -> Tensor: + weight = self.weight.to(dtype=x.dtype) if self.weight is not None else None + bias = self.bias.to(dtype=x.dtype) if self.bias is not None else None + return F.layer_norm(x, self.normalized_shape, weight, bias, self.eps) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +def torch_version_prefix() -> tuple[int, ...]: + version = torch.__version__.split("+", 1)[0] + parts = [] + for piece in version.split("."): + digits = "".join(ch for ch in piece if ch.isdigit()) + if not digits: + break + parts.append(int(digits)) + return tuple(parts) + + +def env_bool(name: str, default: bool) -> bool: + value = os.environ.get(name) + if value is None: + return default + return value.lower() not in {"0", "false", "no"} + + +DEFAULT_SDPA_SUPPORTS_ENABLE_GQA = torch_version_prefix() >= (2, 5) +SDPA_SUPPORTS_ENABLE_GQA = env_bool( + "SDPA_SUPPORTS_ENABLE_GQA", + DEFAULT_SDPA_SUPPORTS_ENABLE_GQA, +) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.kv_head_repeats = self.num_heads // self.num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + if self.num_kv_heads != self.num_heads and not SDPA_SUPPORTS_ENABLE_GQA: + k = k.repeat_interleave(self.kv_head_repeats, dim=1) + v = v.repeat_interleave(self.kv_head_repeats, dim=1) + sdpa_kwargs = { + "attn_mask": None, + "is_causal": True, + } + if SDPA_SUPPORTS_ENABLE_GQA: + sdpa_kwargs["enable_gqa"] = self.num_kv_heads != self.num_heads + y = F.scaled_dot_product_attention(q, k, v, **sdpa_kwargs) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int, mlp_hidden_dim: int = 0): + super().__init__() + hidden = mlp_hidden_dim if mlp_hidden_dim > 0 else mlp_mult * dim + if hidden <= 0: + raise ValueError(f"MLP hidden dimension must be positive, got {hidden}") + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class BigramHashEmbedding(nn.Module): + def __init__(self, vocab_size: int, dim: int, model_dim: int, scale_init: float): + super().__init__() + if vocab_size <= 1: + raise ValueError(f"BIGRAM_VOCAB_SIZE must be 0 or greater than 1, got {vocab_size}") + if dim <= 0: + raise ValueError(f"BIGRAM_DIM must be positive, got {dim}") + self.vocab_size = vocab_size + self.embed = nn.Embedding(vocab_size, dim) + self.proj = CastedLinear(dim, model_dim, bias=False) if dim != model_dim else None + self.scale = nn.Parameter(torch.full((model_dim,), scale_init, dtype=torch.float32)) + nn.init.zeros_(self.embed.weight) + + def forward(self, input_ids: Tensor) -> Tensor: + prev = F.pad(input_ids[:, :-1], (1, 0), value=0).long() + current = input_ids.long() + # Reserve bucket 0 for the artificial BOS previous-token context. + hashed = ((prev * 1_000_003 + current) % (self.vocab_size - 1)) + 1 + hashed[:, 0] = 0 + out = self.embed(hashed) + if self.proj is not None: + out = self.proj(out) + return out * self.scale.to(dtype=out.dtype)[None, None, :] + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + mlp_hidden_dim: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult, mlp_hidden_dim) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class MambaSequenceBlock(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_mult: int, + mlp_hidden_dim: int, + *, + d_state: int, + expand: int, + ngroups: int, + chunk_size: int, + is_mimo: bool, + mimo_rank: int, + include_mlp: bool, + residual_scale_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads for Mamba3 headdim") + if d_state <= 0 or expand <= 0 or ngroups <= 0 or chunk_size <= 0: + raise ValueError( + "MAMBA_D_STATE, MAMBA_EXPAND, MAMBA_NGROUPS, and MAMBA_CHUNK_SIZE must be positive" + ) + if is_mimo and mimo_rank < 2: + raise ValueError("MAMBA_IS_MIMO=1 requires MAMBA_MIMO_RANK >= 2") + if not is_mimo and mimo_rank != 1: + raise ValueError("MAMBA_IS_MIMO=0 requires MAMBA_MIMO_RANK=1") + try: + from mamba_ssm import Mamba3 + except Exception as exc: + raise RuntimeError( + "MODEL_FAMILY=mamba_hybrid requires the official mamba_ssm package. " + "Bootstrap the CUDA Mamba environment before running this model." + ) from exc + + self.input_norm = RMSNorm() + self.output_norm = RMSNorm() + self.include_mlp = include_mlp + mamba_kwargs = { + "d_model": dim, + "d_state": d_state, + "expand": expand, + "headdim": dim // num_heads, + "ngroups": ngroups, + "is_mimo": is_mimo, + "mimo_rank": mimo_rank, + "chunk_size": chunk_size, + "is_outproj_norm": False, + "dtype": torch.bfloat16, + } + try: + self.mixer = Mamba3(**mamba_kwargs) + except TypeError as exc: + if "dtype" not in str(exc): + raise + mamba_kwargs.pop("dtype") + self.mixer = Mamba3(**mamba_kwargs) + self.mlp = MLP(dim, mlp_mult, mlp_hidden_dim) if include_mlp else None + self.mixer_scale = nn.Parameter(torch.full((dim,), residual_scale_init, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + mixed = self.mixer(self.input_norm(x)) + x = x + self.mixer_scale.to(dtype=x.dtype)[None, None, :] * mixed + if self.mlp is not None: + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.output_norm(x)) + return x + + +class P20PrimitiveBlock(nn.Module): + def __init__( + self, + dim: int, + mlp_mult: int, + mlp_hidden_dim: int, + *, + runtime_backend: str, + state_transform_blocks: int, + scan_chunk_size: int = 0, + adapter_dim: int = 0, + adapter_scale_init: float = 0.5, + loop_repeats: int = 1, + loop_delta_scale: float = 1.0, + primitive_loop_repeats: int = 1, + primitive_loop_delta_scale: float = 1.0, + ): + super().__init__() + if loop_repeats <= 0: + raise ValueError(f"P20 loop_repeats must be positive, got {loop_repeats}") + if loop_delta_scale <= 0.0: + raise ValueError(f"P20 loop_delta_scale must be positive, got {loop_delta_scale}") + if primitive_loop_repeats <= 0: + raise ValueError(f"P20 primitive_loop_repeats must be positive, got {primitive_loop_repeats}") + if primitive_loop_delta_scale <= 0.0: + raise ValueError( + f"P20 primitive_loop_delta_scale must be positive, got {primitive_loop_delta_scale}" + ) + if adapter_dim < 0: + raise ValueError(f"P20_ADAPTER_DIM must be non-negative, got {adapter_dim}") + self.loop_repeats = loop_repeats + self.loop_delta_scale = loop_delta_scale + self.primitive_loop_repeats = primitive_loop_repeats + self.primitive_loop_delta_scale = primitive_loop_delta_scale + self.input_norm = CastedLayerNorm(dim) + self.output_norm = CastedLayerNorm(dim) + self.primitive = P20RuntimeSequenceMixer( + dim, + state_transform_blocks=state_transform_blocks, + runtime_backend=runtime_backend, + scan_chunk_size=scan_chunk_size, + ) + self.readout_projection = CastedLinear(dim, dim, bias=True) + self.residual_scale = nn.Parameter(torch.full((dim,), 0.5, dtype=torch.float32)) + self.p20_ttt_residual_gate: Tensor | None = None + self.p20_adapter_down: CastedLinear | None = None + self.p20_adapter_up: CastedLinear | None = None + self.p20_adapter_scale: nn.Parameter | None = None + if adapter_dim > 0: + self.p20_adapter_down = CastedLinear(dim, adapter_dim, bias=False) + self.p20_adapter_up = CastedLinear(adapter_dim, dim, bias=False) + self.p20_adapter_up._zero_init = True + self.p20_adapter_scale = nn.Parameter(torch.full((dim,), adapter_scale_init, dtype=torch.float32)) + self.mlp = MLP(dim, mlp_mult, mlp_hidden_dim) + + def _run_primitive_loop(self, normed: Tensor) -> Tensor: + mixed = self.primitive(normed) + for _ in range(1, self.primitive_loop_repeats): + next_mixed = self.primitive(mixed) + if self.primitive_loop_delta_scale == 1.0: + mixed = next_mixed + else: + mixed = mixed + (next_mixed - mixed) * self.primitive_loop_delta_scale + return mixed + + def _forward_once(self, x: Tensor) -> Tensor: + for _ in range(self.loop_repeats): + normed = self.input_norm(x) + mixed = self._run_primitive_loop(normed) + readout = self.readout_projection(mixed) + if self.p20_adapter_down is not None and self.p20_adapter_up is not None: + adapter = torch.relu(self.p20_adapter_down(normed)).square() + adapter = self.p20_adapter_up(adapter) + if self.p20_adapter_scale is not None: + adapter = adapter * self.p20_adapter_scale.to(dtype=x.dtype)[None, None, :] + readout = readout + adapter + residual_scale = self.residual_scale + if self.p20_ttt_residual_gate is not None: + gate = 1.0 + torch.tanh(self.p20_ttt_residual_gate) + residual_scale = residual_scale * gate + next_x = x + readout * residual_scale.to(dtype=x.dtype)[None, None, :] + next_x = next_x + self.mlp(self.output_norm(next_x)) + if self.loop_repeats == 1 and self.loop_delta_scale == 1.0: + x = next_x + else: + x = x + (next_x - x) * self.loop_delta_scale + return x + + def forward(self, x: Tensor) -> Tensor: + return self._forward_once(x) + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + mlp_hidden_dim: int, + bigram_vocab_size: int, + bigram_dim: int, + bigram_scale_init: float, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + loop_repeats: int = 1, + loop_layer_indices: frozenset[int] = frozenset(), + loop_delta_scale: float = 1.0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + if loop_repeats <= 0: + raise ValueError(f"baseline loop_repeats must be positive, got {loop_repeats}") + if loop_delta_scale <= 0.0: + raise ValueError(f"baseline loop_delta_scale must be positive, got {loop_delta_scale}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.loop_repeats = loop_repeats + self.loop_layer_indices = loop_layer_indices + self.loop_delta_scale = loop_delta_scale + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = ( + BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim, bigram_scale_init) + if bigram_vocab_size > 0 + else None + ) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + mlp_hidden_dim, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _apply_block(self, index: int, x: Tensor, x0: Tensor) -> Tensor: + repeats = self.loop_repeats if index in self.loop_layer_indices else 1 + for _ in range(repeats): + next_x = self.blocks[index](x, x0) + if repeats == 1 and self.loop_delta_scale == 1.0: + x = next_x + else: + x = x + (next_x - x) * self.loop_delta_scale + return x + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + for i in range(self.num_encoder_layers): + x = self._apply_block(i, x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + block_index = self.num_encoder_layers + i + x = self._apply_block(block_index, x, x0) + + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.forward_logits(input_ids).reshape(-1, self.tok_emb.num_embeddings) + targets = target_ids.reshape(-1) + return F.cross_entropy(x.float(), targets, reduction="mean") + + +class P20HybridGPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + mlp_hidden_dim: int, + bigram_vocab_size: int, + bigram_dim: int, + bigram_scale_init: float, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + *, + runtime_backend: str, + layer_schedule: str, + state_transform_blocks: int, + scan_chunk_size: int, + p20_adapter_dim: int, + p20_adapter_scale_init: float, + p20_loop_repeats: int, + loop_delta_scale: float, + p20_primitive_loop_repeats: int, + p20_primitive_loop_repeats_by_p: tuple[int, ...], + p20_primitive_loop_delta_scale: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + if p20_loop_repeats <= 0: + raise ValueError(f"P20_LOOP_REPEATS must be positive, got {p20_loop_repeats}") + if loop_delta_scale <= 0.0: + raise ValueError(f"LOOP_DELTA_SCALE must be positive, got {loop_delta_scale}") + if p20_primitive_loop_repeats <= 0: + raise ValueError( + f"P20_PRIMITIVE_LOOP_REPEATS must be positive, got {p20_primitive_loop_repeats}" + ) + if p20_primitive_loop_delta_scale <= 0.0: + raise ValueError( + "P20_PRIMITIVE_LOOP_DELTA_SCALE must be positive, " + f"got {p20_primitive_loop_delta_scale}" + ) + validate_p20_runtime_backend(runtime_backend) + schedule = normalize_p20_layer_schedule(layer_schedule or default_p20_layer_schedule(num_layers), num_layers) + if len(p20_primitive_loop_repeats_by_p) != schedule.count("P"): + raise ValueError( + "P20_PRIMITIVE_LOOP_REPEATS_BY_P must resolve to one value per primitive layer, " + f"got {p20_primitive_loop_repeats_by_p} for schedule {schedule!r}" + ) + for p_index, repeats in enumerate(p20_primitive_loop_repeats_by_p): + if repeats <= 0: + raise ValueError( + f"P20 primitive loop repeat at P ordinal {p_index} must be positive, got {repeats}" + ) + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.runtime_backend = runtime_backend + self.layer_schedule = schedule + self.state_transform_blocks = state_transform_blocks + self.scan_chunk_size = scan_chunk_size + self.p20_loop_repeats = p20_loop_repeats + self.loop_delta_scale = loop_delta_scale + self.p20_primitive_loop_repeats = p20_primitive_loop_repeats + self.p20_primitive_loop_repeats_by_p = p20_primitive_loop_repeats_by_p + self.p20_primitive_loop_delta_scale = p20_primitive_loop_delta_scale + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = ( + BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim, bigram_scale_init) + if bigram_vocab_size > 0 + else None + ) + blocks = [] + p_index = 0 + for stage in schedule: + if stage == "A": + blocks.append(Block(model_dim, num_heads, num_kv_heads, mlp_mult, mlp_hidden_dim, rope_base, qk_gain_init)) + continue + primitive_loop_repeats = p20_primitive_loop_repeats_by_p[p_index] + p_index += 1 + blocks.append( + P20PrimitiveBlock( + model_dim, + mlp_mult, + mlp_hidden_dim, + runtime_backend=runtime_backend, + state_transform_blocks=state_transform_blocks, + scan_chunk_size=scan_chunk_size, + adapter_dim=p20_adapter_dim, + adapter_scale_init=p20_adapter_scale_init, + loop_repeats=p20_loop_repeats, + loop_delta_scale=loop_delta_scale, + primitive_loop_repeats=primitive_loop_repeats, + primitive_loop_delta_scale=p20_primitive_loop_delta_scale, + ) + ) + self.blocks = nn.ModuleList(blocks) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + for block in self.blocks: + if isinstance(block, Block): + x = block(x, x0) + else: + x = block(x) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.forward_logits(input_ids).reshape(-1, self.tok_emb.num_embeddings) + targets = target_ids.reshape(-1) + return F.cross_entropy(x.float(), targets, reduction="mean") + + +class MambaHybridGPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + mlp_hidden_dim: int, + bigram_vocab_size: int, + bigram_dim: int, + bigram_scale_init: float, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + *, + layer_schedule: str, + d_state: int, + expand: int, + ngroups: int, + chunk_size: int, + is_mimo: bool, + mimo_rank: int, + include_mlp: bool, + residual_scale_init: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + schedule = normalize_mamba_layer_schedule(layer_schedule or default_mamba_layer_schedule(num_layers), num_layers) + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.layer_schedule = schedule + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = ( + BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim, bigram_scale_init) + if bigram_vocab_size > 0 + else None + ) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + blocks: list[nn.Module] = [] + for stage in schedule: + if stage == "A": + blocks.append(Block(model_dim, num_heads, num_kv_heads, mlp_mult, mlp_hidden_dim, rope_base, qk_gain_init)) + else: + blocks.append( + MambaSequenceBlock( + model_dim, + num_heads, + mlp_mult, + mlp_hidden_dim, + d_state=d_state, + expand=expand, + ngroups=ngroups, + chunk_size=chunk_size, + is_mimo=is_mimo, + mimo_rank=mimo_rank, + include_mlp=include_mlp, + residual_scale_init=residual_scale_init, + ) + ) + self.blocks = nn.ModuleList(blocks) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _apply_block(self, index: int, x: Tensor, x0: Tensor) -> Tensor: + block = self.blocks[index] + return block(x, x0) + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + for i in range(self.num_encoder_layers): + x = self._apply_block(i, x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + block_index = self.num_encoder_layers + i + x = self._apply_block(block_index, x, x0) + + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.forward_logits(input_ids).reshape(-1, self.tok_emb.num_embeddings) + targets = target_ids.reshape(-1) + return F.cross_entropy(x.float(), targets, reduction="mean") + + +def default_recurrence_encoder_schedule(num_layers: int) -> tuple[int, ...]: + return tuple(range(num_layers // 2)) + + +def default_recurrence_decoder_schedule(num_layers: int) -> tuple[int, ...]: + return tuple(range(num_layers // 2, num_layers)) + + +def parse_recurrence_schedule(raw_schedule: str, num_layers: int, name: str, default: tuple[int, ...]) -> tuple[int, ...]: + if not raw_schedule.strip(): + return default + schedule = tuple(int(piece.strip()) for piece in raw_schedule.split(",") if piece.strip()) + if not schedule: + raise ValueError(f"{name} must not be empty") + for index in schedule: + if index < 0 or index >= num_layers: + raise ValueError(f"{name} index {index} is out of range for NUM_LAYERS={num_layers}") + return schedule + + +def parse_recurrence_aux_applications(raw_applications: str, total_applications: int) -> frozenset[int]: + if not raw_applications.strip(): + return frozenset() + applications = frozenset(int(piece.strip()) for piece in raw_applications.split(",") if piece.strip()) + for application in applications: + if application < 0 or application >= total_applications: + raise ValueError( + f"RECURRENCE_AUX_APPLICATIONS entry {application} is out of range " + f"for {total_applications} virtual applications" + ) + return applications + + +def normalize_recurrence_aux_kind(raw_kind: str) -> str: + kind = raw_kind.strip().lower() + if kind not in {"none", "p20", "mamba"}: + raise ValueError(f"RECURRENCE_AUX_KIND must be none|p20|mamba, got {raw_kind!r}") + return kind + + +def normalize_recurrence_aux_mode(raw_mode: str, aux_kind: str) -> str: + mode = raw_mode.strip().lower() + if aux_kind == "none": + if mode not in {"none", ""}: + raise ValueError("RECURRENCE_AUX_MODE must be none when RECURRENCE_AUX_KIND=none") + return "none" + if mode not in {"substitute", "parallel", "pre", "post"}: + raise ValueError( + "RECURRENCE_AUX_MODE must be substitute|parallel|pre|post " + f"when RECURRENCE_AUX_KIND={aux_kind}, got {raw_mode!r}" + ) + return mode + + +class RecurrentHybridGPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + mlp_hidden_dim: int, + bigram_vocab_size: int, + bigram_dim: int, + bigram_scale_init: float, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + *, + encoder_schedule: tuple[int, ...], + decoder_schedule: tuple[int, ...], + aux_kind: str, + aux_mode: str, + aux_applications: frozenset[int], + aux_scale_init: float, + p20_runtime_backend: str, + p20_state_blocks: int, + p20_scan_chunk_size: int, + p20_adapter_dim: int, + p20_adapter_scale_init: float, + p20_primitive_loop_repeats: int, + p20_primitive_loop_delta_scale: float, + mamba_d_state: int, + mamba_expand: int, + mamba_ngroups: int, + mamba_chunk_size: int, + mamba_is_mimo: bool, + mamba_mimo_rank: int, + mamba_include_mlp: bool, + mamba_residual_scale_init: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + if aux_kind == "none" and aux_applications: + raise ValueError("RECURRENCE_AUX_APPLICATIONS requires RECURRENCE_AUX_KIND=p20|mamba") + if aux_kind != "none" and not aux_applications: + raise ValueError("RECURRENCE_AUX_APPLICATIONS must name at least one virtual application") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.base_encoder_schedule = default_recurrence_encoder_schedule(num_layers) + self.base_decoder_schedule = default_recurrence_decoder_schedule(num_layers) + self.encoder_schedule = encoder_schedule + self.decoder_schedule = decoder_schedule + self.aux_kind = aux_kind + self.aux_mode = aux_mode + self.aux_applications = aux_applications + self.recurrence_active = True + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = ( + BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim, bigram_scale_init) + if bigram_vocab_size > 0 + else None + ) + self.skip_weights = nn.Parameter( + torch.ones(min(len(encoder_schedule), len(decoder_schedule)), model_dim, dtype=torch.float32) + ) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + mlp_hidden_dim, + rope_base, + qk_gain_init, + ) + for _ in range(num_layers) + ] + ) + self.aux_scale = ( + nn.Parameter(torch.full((model_dim,), aux_scale_init, dtype=torch.float32)) + if aux_kind != "none" + else None + ) + if aux_kind == "p20": + validate_p20_runtime_backend(p20_runtime_backend) + self.aux_block: nn.Module | None = P20PrimitiveBlock( + model_dim, + mlp_mult, + mlp_hidden_dim, + runtime_backend=p20_runtime_backend, + state_transform_blocks=p20_state_blocks, + scan_chunk_size=p20_scan_chunk_size, + adapter_dim=p20_adapter_dim, + adapter_scale_init=p20_adapter_scale_init, + loop_repeats=1, + loop_delta_scale=1.0, + primitive_loop_repeats=p20_primitive_loop_repeats, + primitive_loop_delta_scale=p20_primitive_loop_delta_scale, + ) + elif aux_kind == "mamba": + self.aux_block = MambaSequenceBlock( + model_dim, + num_heads, + mlp_mult, + mlp_hidden_dim, + d_state=mamba_d_state, + expand=mamba_expand, + ngroups=mamba_ngroups, + chunk_size=mamba_chunk_size, + is_mimo=mamba_is_mimo, + mimo_rank=mamba_mimo_rank, + include_mlp=mamba_include_mlp, + residual_scale_init=mamba_residual_scale_init, + ) + else: + self.aux_block = None + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def set_recurrence_active(self, active: bool) -> None: + self.recurrence_active = active + + def _apply_aux(self, x: Tensor, x0: Tensor) -> Tensor: + if self.aux_block is None: + raise RuntimeError("auxiliary recurrence block is not configured") + if self.aux_kind == "p20": + return self.aux_block(x) + if self.aux_kind == "mamba": + return self.aux_block(x, x0) + raise RuntimeError(f"unsupported auxiliary recurrence kind {self.aux_kind!r}") + + def _apply_virtual(self, block_index: int, app_index: int, x: Tensor, x0: Tensor) -> Tensor: + block = self.blocks[block_index] + if not self.recurrence_active or app_index not in self.aux_applications: + return block(x, x0) + if self.aux_mode == "substitute": + return self._apply_aux(x, x0) + if self.aux_mode == "pre": + return block(self._apply_aux(x, x0), x0) + if self.aux_mode == "post": + return self._apply_aux(block(x, x0), x0) + if self.aux_mode == "parallel": + if self.aux_scale is None: + raise RuntimeError("parallel auxiliary recurrence requires aux_scale") + block_out = block(x, x0) + aux_out = self._apply_aux(x, x0) + aux_delta = aux_out - x + return block_out + self.aux_scale.to(dtype=x.dtype)[None, None, :] * aux_delta + return block(x, x0) + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + encoder_schedule = self.encoder_schedule if self.recurrence_active else self.base_encoder_schedule + decoder_schedule = self.decoder_schedule if self.recurrence_active else self.base_decoder_schedule + + app_index = 0 + for block_index in encoder_schedule: + x = self._apply_virtual(block_index, app_index, x, x0) + skips.append(x) + app_index += 1 + for i, block_index in enumerate(decoder_schedule): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self._apply_virtual(block_index, app_index, x, x0) + app_index += 1 + + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.forward_logits(input_ids).reshape(-1, self.tok_emb.num_embeddings) + targets = target_ids.reshape(-1) + return F.cross_entropy(x.float(), targets, reduction="mean") + + +def build_model(args: Hyperparameters) -> nn.Module: + common_kwargs = dict( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + mlp_hidden_dim=args.mlp_hidden_dim, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + bigram_scale_init=args.bigram_scale_init, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ) + p20_state_blocks = resolve_p20_state_blocks( + args.p20_state_blocks, + args.model_dim, + args.p20_block_pair_width_cap, + ) + if args.model_family == "baseline": + return GPT( + **common_kwargs, + loop_repeats=args.baseline_loop_repeats, + loop_layer_indices=parse_loop_layer_indices(args.baseline_loop_layer_indices, args.num_layers), + loop_delta_scale=args.loop_delta_scale, + ) + if args.model_family == "recurrent_hybrid": + encoder_schedule = parse_recurrence_schedule( + args.recurrence_encoder_schedule, + args.num_layers, + "RECURRENCE_ENCODER_SCHEDULE", + default_recurrence_encoder_schedule(args.num_layers), + ) + decoder_schedule = parse_recurrence_schedule( + args.recurrence_decoder_schedule, + args.num_layers, + "RECURRENCE_DECODER_SCHEDULE", + default_recurrence_decoder_schedule(args.num_layers), + ) + aux_kind = normalize_recurrence_aux_kind(args.recurrence_aux_kind) + aux_mode = normalize_recurrence_aux_mode(args.recurrence_aux_mode, aux_kind) + return RecurrentHybridGPT( + **common_kwargs, + encoder_schedule=encoder_schedule, + decoder_schedule=decoder_schedule, + aux_kind=aux_kind, + aux_mode=aux_mode, + aux_applications=parse_recurrence_aux_applications( + args.recurrence_aux_applications, + len(encoder_schedule) + len(decoder_schedule), + ), + aux_scale_init=args.recurrence_aux_scale_init, + p20_runtime_backend=args.p20_runtime_backend, + p20_state_blocks=p20_state_blocks, + p20_scan_chunk_size=args.p20_scan_chunk_size, + p20_adapter_dim=args.p20_adapter_dim, + p20_adapter_scale_init=args.p20_adapter_scale_init, + p20_primitive_loop_repeats=args.p20_primitive_loop_repeats, + p20_primitive_loop_delta_scale=args.p20_primitive_loop_delta_scale, + mamba_d_state=args.mamba_d_state, + mamba_expand=args.mamba_expand, + mamba_ngroups=args.mamba_ngroups, + mamba_chunk_size=args.mamba_chunk_size, + mamba_is_mimo=args.mamba_is_mimo, + mamba_mimo_rank=args.mamba_mimo_rank, + mamba_include_mlp=args.mamba_include_mlp, + mamba_residual_scale_init=args.mamba_residual_scale_init, + ) + if args.model_family == "p20_hybrid": + layer_schedule = normalize_p20_layer_schedule( + args.p20_layer_schedule or default_p20_layer_schedule(args.num_layers), + args.num_layers, + ) + p20_primitive_loop_repeats_by_p = parse_p20_primitive_loop_repeats_by_p( + args.p20_primitive_loop_repeats_by_p, + layer_schedule, + args.p20_primitive_loop_repeats, + ) + return P20HybridGPT( + **common_kwargs, + runtime_backend=args.p20_runtime_backend, + layer_schedule=layer_schedule, + state_transform_blocks=p20_state_blocks, + scan_chunk_size=args.p20_scan_chunk_size, + p20_adapter_dim=args.p20_adapter_dim, + p20_adapter_scale_init=args.p20_adapter_scale_init, + p20_loop_repeats=args.p20_loop_repeats, + loop_delta_scale=args.loop_delta_scale, + p20_primitive_loop_repeats=args.p20_primitive_loop_repeats, + p20_primitive_loop_repeats_by_p=p20_primitive_loop_repeats_by_p, + p20_primitive_loop_delta_scale=args.p20_primitive_loop_delta_scale, + ) + if args.model_family == "mamba_hybrid": + return MambaHybridGPT( + **common_kwargs, + layer_schedule=normalize_mamba_layer_schedule( + args.mamba_layer_schedule or default_mamba_layer_schedule(args.num_layers), + args.num_layers, + ), + d_state=args.mamba_d_state, + expand=args.mamba_expand, + ngroups=args.mamba_ngroups, + chunk_size=args.mamba_chunk_size, + is_mimo=args.mamba_is_mimo, + mimo_rank=args.mamba_mimo_rank, + include_mlp=args.mamba_include_mlp, + residual_scale_init=args.mamba_residual_scale_init, + ) + raise ValueError( + f"MODEL_FAMILY must be baseline|recurrent_hybrid|p20_hybrid|mamba_hybrid, got {args.model_family}" + ) + + +def resolve_compile_model(args: Hyperparameters) -> bool: + value = args.compile_model.strip().lower() + if value == "auto": + return args.model_family == "baseline" + if value in {"1", "true", "yes", "on"}: + return True + if value in {"0", "false", "no", "off"}: + return False + raise ValueError(f"COMPILE_MODEL must be auto|0|1, got {args.compile_model}") + + +def resolve_compile_muon_backend(args: Hyperparameters) -> bool: + value = args.compile_muon_backend.strip().lower() + if value == "auto": + return args.model_family != "mamba_hybrid" and not ( + args.model_family == "recurrent_hybrid" + and normalize_recurrence_aux_kind(args.recurrence_aux_kind) == "mamba" + ) + if value in {"1", "true", "yes", "on"}: + return True + if value in {"0", "false", "no", "off"}: + return False + raise ValueError(f"COMPILE_MUON_BACKEND must be auto|0|1, got {args.compile_muon_backend}") + + +def normalize_optimizer_family(value: str) -> str: + normalized = value.strip().lower() + if normalized not in {"muon", "adam"}: + raise ValueError(f"OPTIMIZER_FAMILY must be muon|adam, got {value!r}") + return normalized + + +def default_p20_layer_schedule(num_layers: int) -> str: + return "".join("A" if index % 2 == 0 else "P" for index in range(num_layers)) + + +def default_mamba_layer_schedule(num_layers: int) -> str: + if num_layers <= 0: + raise ValueError(f"NUM_LAYERS must be positive, got {num_layers}") + center = num_layers // 2 + return "".join("M" if index == center else "A" for index in range(num_layers)) + + +def normalize_p20_layer_schedule(raw_schedule: str, num_layers: int) -> str: + schedule = "".join(char for char in raw_schedule.upper() if char in {"A", "P"}) + if len(schedule) != num_layers: + raise ValueError( + f"P20_LAYER_SCHEDULE must contain exactly {num_layers} A/P stages, got {raw_schedule!r}" + ) + if "P" not in schedule: + raise ValueError("P20_LAYER_SCHEDULE must include at least one primitive layer") + return schedule + + +def normalize_mamba_layer_schedule(raw_schedule: str, num_layers: int) -> str: + schedule = "".join(char for char in raw_schedule.upper() if char in {"A", "M"}) + if len(schedule) != num_layers: + raise ValueError( + f"MAMBA_LAYER_SCHEDULE must contain exactly {num_layers} A/M stages, got {raw_schedule!r}" + ) + if "M" not in schedule: + raise ValueError("MAMBA_LAYER_SCHEDULE must include at least one Mamba layer") + return schedule + + +def parse_loop_layer_indices(raw_indices: str, num_layers: int) -> frozenset[int]: + if not raw_indices.strip(): + return frozenset() + indices: set[int] = set() + for raw_piece in raw_indices.split(","): + piece = raw_piece.strip() + if not piece: + continue + index = int(piece) + if index < 0 or index >= num_layers: + raise ValueError( + f"loop layer index {index} is out of range for NUM_LAYERS={num_layers}" + ) + indices.add(index) + return frozenset(indices) + + +def parse_p20_primitive_loop_repeats_by_p( + raw_repeats: str, + layer_schedule: str, + default_repeats: int, +) -> tuple[int, ...]: + p_count = layer_schedule.count("P") + if default_repeats <= 0: + raise ValueError(f"P20_PRIMITIVE_LOOP_REPEATS must be positive, got {default_repeats}") + if not raw_repeats.strip(): + return tuple(default_repeats for _ in range(p_count)) + repeats = tuple(int(piece.strip()) for piece in raw_repeats.split(",") if piece.strip()) + if len(repeats) != p_count: + raise ValueError( + "P20_PRIMITIVE_LOOP_REPEATS_BY_P must contain exactly one comma-separated " + f"value per primitive layer; got {raw_repeats!r} for schedule {layer_schedule!r}" + ) + for p_index, repeat_count in enumerate(repeats): + if repeat_count <= 0: + raise ValueError( + f"P20_PRIMITIVE_LOOP_REPEATS_BY_P value at P ordinal {p_index} " + f"must be positive, got {repeat_count}" + ) + return repeats + + +def resolve_p20_state_blocks(raw_value: str, model_dim: int, pair_width_cap: int) -> int: + value = raw_value.strip().lower() + if value != "auto": + blocks = int(value) + if blocks <= 0: + raise ValueError(f"P20_STATE_BLOCKS must be positive or auto, got {raw_value!r}") + return blocks + if pair_width_cap <= 0: + raise ValueError(f"P20_BLOCK_PAIR_WIDTH_CAP must be positive for auto tiling, got {pair_width_cap}") + for blocks in range(1, model_dim + 1): + if model_dim % blocks != 0: + continue + block_width = model_dim // blocks + if block_width % 2 == 0 and block_width // 2 <= pair_width_cap: + return blocks + raise ValueError( + f"could not resolve P20_STATE_BLOCKS=auto for MODEL_DIM={model_dim} " + f"and P20_BLOCK_PAIR_WIDTH_CAP={pair_width_cap}" + ) + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + use_compiled_muon_backend = resolve_compile_muon_backend(args) + if use_compiled_muon_backend: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + out_dir = Path(args.out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + summary_path = Path(args.run_summary_path) if args.run_summary_path else out_dir / f"{args.run_id}_summary.json" + resolved_p20_ttt_mode = normalize_p20_ttt_mode(args.p20_ttt_mode) + if resolved_p20_ttt_mode != "off" and args.model_family != "p20_hybrid": + raise ValueError("P20_TTT_MODE requires MODEL_FAMILY=p20_hybrid") + if args.p20_ttt_chunk_size <= 0: + raise ValueError(f"P20_TTT_CHUNK_SIZE must be positive, got {args.p20_ttt_chunk_size}") + if args.p20_ttt_context_size < 0: + raise ValueError(f"P20_TTT_CONTEXT_SIZE must be non-negative, got {args.p20_ttt_context_size}") + if args.p20_ttt_context_size and args.p20_ttt_context_size < args.p20_ttt_chunk_size: + raise ValueError( + "P20_TTT_CONTEXT_SIZE must be 0 or at least P20_TTT_CHUNK_SIZE; " + f"got context={args.p20_ttt_context_size} chunk={args.p20_ttt_chunk_size}" + ) + if args.p20_ttt_bootstrap_samples < 0: + raise ValueError(f"P20_TTT_BOOTSTRAP_SAMPLES must be non-negative, got {args.p20_ttt_bootstrap_samples}") + if args.p20_adapter_dim < 0: + raise ValueError(f"P20_ADAPTER_DIM must be non-negative, got {args.p20_adapter_dim}") + if args.mlp_hidden_dim < 0: + raise ValueError(f"MLP_HIDDEN_DIM must be non-negative, got {args.mlp_hidden_dim}") + if args.bigram_vocab_size < 0: + raise ValueError(f"BIGRAM_VOCAB_SIZE must be non-negative, got {args.bigram_vocab_size}") + if args.bigram_vocab_size == 1: + raise ValueError("BIGRAM_VOCAB_SIZE must be 0 or greater than 1") + if args.bigram_dim <= 0: + raise ValueError(f"BIGRAM_DIM must be positive, got {args.bigram_dim}") + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + logfile = out_dir / f"{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_file_paths = sorted(dataset_dir.glob("fineweb_train_*.bin")) + actual_train_files = len(actual_train_file_paths) + actual_train_tokens = count_data_shard_tokens(actual_train_file_paths) + if args.min_train_shards > 0 and actual_train_files < args.min_train_shards: + raise ValueError( + f"Corpus preflight failed: found {actual_train_files} train shards in {dataset_dir}, " + f"but MIN_TRAIN_SHARDS={args.min_train_shards}" + ) + if args.min_train_tokens > 0 and actual_train_tokens < args.min_train_tokens: + raise ValueError( + f"Corpus preflight failed: found {actual_train_tokens} train tokens in {dataset_dir}, " + f"but MIN_TRAIN_TOKENS={args.min_train_tokens}" + ) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files} train_tokens:{actual_train_tokens}") + if args.min_train_shards > 0 or args.min_train_tokens > 0: + log0( + f"corpus_preflight:passed min_train_shards:{args.min_train_shards} " + f"min_train_tokens:{args.min_train_tokens}" + ) + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + log0(f"model_family:{args.model_family} p20_runtime_backend:{args.p20_runtime_backend}") + resolved_mlp_hidden_dim = args.mlp_hidden_dim if args.mlp_hidden_dim > 0 else args.mlp_mult * args.model_dim + log0(f"mlp_mult:{args.mlp_mult} mlp_hidden_dim:{resolved_mlp_hidden_dim}") + log0( + f"bigram_vocab_size:{args.bigram_vocab_size} " + f"bigram_dim:{args.bigram_dim} bigram_scale_init:{args.bigram_scale_init}" + ) + resolved_p20_state_blocks = ( + resolve_p20_state_blocks(args.p20_state_blocks, args.model_dim, args.p20_block_pair_width_cap) + if args.model_family == "p20_hybrid" + else None + ) + resolved_p20_layer_schedule = ( + normalize_p20_layer_schedule(args.p20_layer_schedule or default_p20_layer_schedule(args.num_layers), args.num_layers) + if args.model_family == "p20_hybrid" + else None + ) + resolved_p20_primitive_loop_repeats_by_p = ( + parse_p20_primitive_loop_repeats_by_p( + args.p20_primitive_loop_repeats_by_p, + resolved_p20_layer_schedule, + args.p20_primitive_loop_repeats, + ) + if resolved_p20_layer_schedule is not None + else None + ) + resolved_mamba_layer_schedule = ( + normalize_mamba_layer_schedule(args.mamba_layer_schedule or default_mamba_layer_schedule(args.num_layers), args.num_layers) + if args.model_family == "mamba_hybrid" + else None + ) + resolved_recurrence_encoder_schedule: tuple[int, ...] | None = None + resolved_recurrence_decoder_schedule: tuple[int, ...] | None = None + resolved_recurrence_aux_kind: str | None = None + resolved_recurrence_aux_mode: str | None = None + resolved_recurrence_aux_applications: frozenset[int] | None = None + if args.model_family == "recurrent_hybrid": + resolved_recurrence_encoder_schedule = parse_recurrence_schedule( + args.recurrence_encoder_schedule, + args.num_layers, + "RECURRENCE_ENCODER_SCHEDULE", + default_recurrence_encoder_schedule(args.num_layers), + ) + resolved_recurrence_decoder_schedule = parse_recurrence_schedule( + args.recurrence_decoder_schedule, + args.num_layers, + "RECURRENCE_DECODER_SCHEDULE", + default_recurrence_decoder_schedule(args.num_layers), + ) + resolved_recurrence_aux_kind = normalize_recurrence_aux_kind(args.recurrence_aux_kind) + resolved_recurrence_aux_mode = normalize_recurrence_aux_mode( + args.recurrence_aux_mode, + resolved_recurrence_aux_kind, + ) + resolved_recurrence_aux_applications = parse_recurrence_aux_applications( + args.recurrence_aux_applications, + len(resolved_recurrence_encoder_schedule) + len(resolved_recurrence_decoder_schedule), + ) + baseline_loop_layer_indices = ( + parse_loop_layer_indices(args.baseline_loop_layer_indices, args.num_layers) + if args.model_family == "baseline" + else frozenset() + ) + if args.model_family == "p20_hybrid": + log0(f"p20_layer_schedule:{resolved_p20_layer_schedule}") + log0( + f"p20_state_blocks:{resolved_p20_state_blocks} " + f"requested:{args.p20_state_blocks} pair_width_cap:{args.p20_block_pair_width_cap}" + ) + log0(f"p20_scan_chunk_size:{args.p20_scan_chunk_size}") + log0(f"p20_adapter_dim:{args.p20_adapter_dim} p20_adapter_scale_init:{args.p20_adapter_scale_init}") + log0(f"p20_loop_repeats:{args.p20_loop_repeats} loop_delta_scale:{args.loop_delta_scale}") + log0( + f"p20_primitive_loop_repeats:{args.p20_primitive_loop_repeats} " + f"primitive_loop_delta_scale:{args.p20_primitive_loop_delta_scale}" + ) + log0( + "p20_primitive_loop_repeats_by_p:" + f"{list(resolved_p20_primitive_loop_repeats_by_p or [])}" + ) + log0( + f"p20_ttt_mode:{resolved_p20_ttt_mode} " + f"chunk_size:{args.p20_ttt_chunk_size} " + f"context_size:{args.p20_ttt_context_size or args.train_seq_len} " + f"lr:{args.p20_ttt_lr} " + f"betas:({args.p20_ttt_beta1},{args.p20_ttt_beta2}) " + f"weight_decay:{args.p20_ttt_weight_decay} " + f"grad_clip_norm:{args.p20_ttt_grad_clip_norm} " + f"max_docs:{args.p20_ttt_max_docs or 'full'} " + f"bootstrap_samples:{args.p20_ttt_bootstrap_samples}" + ) + if args.model_family == "baseline": + log0( + f"baseline_loop_repeats:{args.baseline_loop_repeats} " + f"baseline_loop_layer_indices:{sorted(baseline_loop_layer_indices)} " + f"loop_delta_scale:{args.loop_delta_scale}" + ) + if args.model_family == "mamba_hybrid": + log0( + f"mamba_layer_schedule:{resolved_mamba_layer_schedule} " + f"d_state:{args.mamba_d_state} expand:{args.mamba_expand} " + f"ngroups:{args.mamba_ngroups} chunk_size:{args.mamba_chunk_size} " + f"is_mimo:{args.mamba_is_mimo} mimo_rank:{args.mamba_mimo_rank} " + f"include_mlp:{args.mamba_include_mlp} residual_scale_init:{args.mamba_residual_scale_init}" + ) + if args.model_family == "recurrent_hybrid": + log0( + f"recurrence_encoder_schedule:{list(resolved_recurrence_encoder_schedule or [])} " + f"recurrence_decoder_schedule:{list(resolved_recurrence_decoder_schedule or [])}" + ) + log0( + f"recurrence_start_fraction:{args.recurrence_start_fraction} " + f"recurrence_start_step:{args.recurrence_start_step} " + f"recurrence_aux_kind:{resolved_recurrence_aux_kind} " + f"recurrence_aux_mode:{resolved_recurrence_aux_mode} " + f"recurrence_aux_applications:{sorted(resolved_recurrence_aux_applications or [])} " + f"recurrence_aux_scale_init:{args.recurrence_aux_scale_init}" + ) + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = build_model(args).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + use_compiled_model = resolve_compile_model(args) + optimizer_family = normalize_optimizer_family(args.optimizer_family) + if use_compiled_model: + compile_fullgraph = args.model_family == "baseline" + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=compile_fullgraph) + else: + compiled_model = base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - non-P20 matrix params use MATRIX_LR via Muon/Adam + # - non-P20 vectors/scalars use SCALAR_LR via Adam + # - P20 params can opt into independent LRs without moving transformer params + special_param_ids = {id(base_model.tok_emb.weight)} + if base_model.lm_head is not None: + special_param_ids.add(id(base_model.lm_head.weight)) + if getattr(base_model, "bigram", None) is not None: + special_param_ids.add(id(base_model.bigram.embed.weight)) + p20_param_ids = { + id(param) + for module in base_model.modules() + if isinstance(module, P20PrimitiveBlock) + for param in module.parameters(recurse=True) + } + model_named_params = [ + (name, p) + for name, p in base_model.named_parameters() + if p.requires_grad and id(p) not in special_param_ids + ] + non_p20_matrix_params = [] + p20_matrix_params = [] + non_p20_scalar_params = [] + p20_scalar_params = [] + for name, p in model_named_params: + is_matrix_owned = p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + is_p20_owned = id(p) in p20_param_ids + if is_matrix_owned and is_p20_owned: + p20_matrix_params.append(p) + elif is_matrix_owned: + non_p20_matrix_params.append(p) + elif is_p20_owned: + p20_scalar_params.append(p) + else: + non_p20_scalar_params.append(p) + p20_matrix_lr = args.p20_matrix_lr if args.p20_matrix_lr is not None else args.matrix_lr + p20_scalar_lr = args.p20_scalar_lr if args.p20_scalar_lr is not None else args.scalar_lr + matrix_param_groups = [] + if non_p20_matrix_params: + matrix_param_groups.append( + {"params": non_p20_matrix_params, "lr": args.matrix_lr, "base_lr": args.matrix_lr, "name": "matrix"} + ) + if p20_matrix_params: + matrix_param_groups.append( + {"params": p20_matrix_params, "lr": p20_matrix_lr, "base_lr": p20_matrix_lr, "name": "p20_matrix"} + ) + scalar_param_groups = [] + if non_p20_scalar_params: + scalar_param_groups.append( + {"params": non_p20_scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr, "name": "scalar"} + ) + if p20_scalar_params: + scalar_param_groups.append( + {"params": p20_scalar_params, "lr": p20_scalar_lr, "base_lr": p20_scalar_lr, "name": "p20_scalar"} + ) + matrix_params = [param for group in matrix_param_groups for param in group["params"]] + scalar_params = [param for group in scalar_param_groups for param in group["params"]] + matrix_param_ids = {id(param) for param in matrix_params} + scalar_param_ids = {id(param) for param in scalar_params} + overlap_count = len(matrix_param_ids & scalar_param_ids) + missing_model_params = [ + (name, tuple(param.shape)) + for name, param in model_named_params + if id(param) not in matrix_param_ids and id(param) not in scalar_param_ids + ] + if overlap_count or missing_model_params: + raise RuntimeError( + "optimizer ownership invariant failed for model parameters: " + f"overlap_count={overlap_count} missing={missing_model_params[:8]}" + ) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + token_params = [base_model.tok_emb.weight] + if getattr(base_model, "bigram", None) is not None: + token_params.append(base_model.bigram.embed.weight) + optimizer_tok = torch.optim.Adam( + [{"params": token_params, "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon: Muon | None = None + if optimizer_family == "muon": + optimizer_muon = Muon( + matrix_param_groups, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + matrix_optimizer: torch.optim.Optimizer = optimizer_muon + else: + matrix_optimizer = torch.optim.Adam( + matrix_param_groups, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_scalar = torch.optim.Adam( + scalar_param_groups, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, matrix_optimizer, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params} compiled_model:{use_compiled_model} compiled_muon_backend:{use_compiled_muon_backend}") + log0(f"optimizer_family:{optimizer_family}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr} " + f"p20_matrix_lr:{p20_matrix_lr if p20_matrix_params else 'n/a'} " + f"p20_scalar_lr:{p20_scalar_lr if p20_scalar_params else 'n/a'}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"val_batch_size:{args.val_batch_size} val_max_batches:{args.val_max_batches or 'full'} " + f"eval_stride:{args.eval_stride or 'off'} sw_eval_batch:{args.sw_eval_batch} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"compile_prewarm_steps:{args.compile_prewarm_steps} " + f"compile_prewarm_counts_toward_training:{args.compile_prewarm_counts_toward_training} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + synthetic_prewarm_generator = torch.Generator(device=device) + synthetic_prewarm_generator.manual_seed(args.seed + 1_000_003 + rank) + + def synthetic_compile_batch() -> tuple[Tensor, Tensor]: + local_tokens = args.train_batch_tokens // (world_size * grad_accum_steps) + local_batch_seqs = local_tokens // args.train_seq_len + if local_batch_seqs <= 0: + raise ValueError( + "TRAIN_BATCH_TOKENS must provide at least one synthetic prewarm sequence per rank; " + f"got TRAIN_BATCH_TOKENS={args.train_batch_tokens}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + tokens = torch.randint( + 0, + args.vocab_size, + (local_batch_seqs, args.train_seq_len + 1), + generator=synthetic_prewarm_generator, + device=device, + dtype=torch.int64, + ) + return tokens[:, :-1].contiguous(), tokens[:, 1:].contiguous() + + def snapshot_execution_state() -> tuple[dict[str, Tensor], list[dict[str, object]], object, tuple[object, ...], Tensor, list[Tensor]]: + return ( + {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()}, + [copy.deepcopy(opt.state_dict()) for opt in optimizers], + random.getstate(), + np.random.get_state(), + torch.random.get_rng_state(), + torch.cuda.get_rng_state_all(), + ) + + def restore_execution_state( + snapshot: tuple[dict[str, Tensor], list[dict[str, object]], object, tuple[object, ...], Tensor, list[Tensor]] + ) -> None: + model_state, optimizer_states, py_rng_state, np_rng_state, torch_rng_state, cuda_rng_states = snapshot + base_model.load_state_dict(model_state, strict=True) + for opt, state in zip(optimizers, optimizer_states, strict=True): + opt.load_state_dict(state) + random.setstate(py_rng_state) + np.random.set_state(np_rng_state) + torch.random.set_rng_state(torch_rng_state) + torch.cuda.set_rng_state_all(cuda_rng_states) + + def run_compile_prewarm(steps: int) -> tuple[int, float]: + if steps <= 0: + return 0, 0.0 + model.train() + torch.cuda.synchronize() + prewarm_t0 = time.perf_counter() + executed = 0 + for prewarm_step in range(steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = synthetic_compile_batch() + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + prewarm_loss = model(x, y) + (prewarm_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + executed = prewarm_step + 1 + log0(f"compile_prewarm_step:{executed}/{steps}") + torch.cuda.synchronize() + return executed, 1000.0 * (time.perf_counter() - prewarm_t0) + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + training_time_ms = 0.0 + compile_prewarm_time_ms = 0.0 + compile_prewarm_steps_executed = 0 + if args.compile_prewarm_steps > 0: + if not use_compiled_model: + log0( + f"compile_prewarm_skipped requested:{args.compile_prewarm_steps} " + f"reason:compiled_model_disabled model_family:{args.model_family}" + ) + else: + # This path is intentionally synthetic and state-restored. It can prime + # lazy compiler/optimizer kernels without silently consuming training + # data or changing the real training trajectory. + prewarm_snapshot = snapshot_execution_state() + compile_prewarm_steps_executed, compile_prewarm_time_ms = run_compile_prewarm(args.compile_prewarm_steps) + restore_execution_state(prewarm_snapshot) + if args.compile_prewarm_counts_toward_training: + training_time_ms += compile_prewarm_time_ms + log0( + f"compile_prewarm_complete steps:{compile_prewarm_steps_executed} " + f"time:{compile_prewarm_time_ms:.0f}ms " + f"counts_toward_training:{args.compile_prewarm_counts_toward_training} " + "data:synthetic state_restored:true" + ) + warmup_time_ms = 0.0 + effective_warmup_steps = args.warmup_steps if use_compiled_model else 0 + warmup_steps_executed = 0 + if args.warmup_steps > 0 and effective_warmup_steps == 0: + log0( + f"warmup_skipped requested:{args.warmup_steps} reason:compiled_model_disabled " + f"model_family:{args.model_family}" + ) + + # Warmup primes compiled paths. If we pay for it, it counts toward the training cap. + if effective_warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + torch.cuda.synchronize() + warmup_t0 = time.perf_counter() + for warmup_step in range(effective_warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + warmup_steps_executed = warmup_step + 1 + if effective_warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == effective_warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{effective_warmup_steps}") + if max_wallclock_ms is not None: + torch.cuda.synchronize() + warmup_elapsed_ms = 1000.0 * (time.perf_counter() - warmup_t0) + if warmup_elapsed_ms >= max_wallclock_ms: + log0( + f"stopping_early: wallclock_cap_during_warmup warmup_time:{warmup_elapsed_ms:.0f}ms " + f"warmup_step:{warmup_step + 1}/{effective_warmup_steps}" + ) + break + torch.cuda.synchronize() + warmup_time_ms = 1000.0 * (time.perf_counter() - warmup_t0) + training_time_ms += warmup_time_ms + log0(f"warmup_complete steps:{warmup_steps_executed} warmup_time:{warmup_time_ms:.0f}ms") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + ema_state: dict[str, Tensor] | None = None + ema_applied = False + ema_eval_time_ms: float | None = None + if args.ema_enabled: + ema_state = { + name: tensor.detach().float().clone() + for name, tensor in base_model.state_dict().items() + if tensor.is_floating_point() + } + log0( + f"ema:enabled decay:{args.ema_decay} start_step:{args.ema_start_step} " + f"tracked_tensors:{len(ema_state)}" + ) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + stop_after_step: int | None = None + if max_wallclock_ms is not None and training_time_ms >= max_wallclock_ms: + stop_after_step = 0 + failure_reason: str | None = None + last_train_loss = float("nan") + val_loss = float("nan") + val_bpb = float("nan") + torch.cuda.synchronize() + t0 = time.perf_counter() + + recurrence_activated_step: int | None = None + recurrence_activated_train_time_ms: float | None = None + + def update_recurrence_active(current_step: int, elapsed_ms: float) -> None: + nonlocal recurrence_activated_step, recurrence_activated_train_time_ms + setter = getattr(base_model, "set_recurrence_active", None) + if setter is None: + return + if args.recurrence_start_step >= 0: + active = current_step >= args.recurrence_start_step + elif args.recurrence_start_fraction <= 0.0: + active = True + elif max_wallclock_ms is not None: + active = elapsed_ms >= max_wallclock_ms * args.recurrence_start_fraction + else: + active = current_step >= math.ceil(args.iterations * args.recurrence_start_fraction) + setter(active) + if active and recurrence_activated_step is None: + recurrence_activated_step = current_step + recurrence_activated_train_time_ms = elapsed_ms + log0( + f"recurrence_activated step:{current_step} " + f"train_time:{elapsed_ms:.0f}ms" + ) + + step = 0 + while True: + loop_elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + update_recurrence_active(step, loop_elapsed_ms) + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + last_train_loss = float(train_loss.item()) + + has_nonfinite_loss = not torch.isfinite(train_loss).item() + if distributed: + nonfinite_tensor = torch.tensor(int(has_nonfinite_loss), device=device) + dist.all_reduce(nonfinite_tensor, op=dist.ReduceOp.MAX) + has_nonfinite_loss = bool(nonfinite_tensor.item()) + if has_nonfinite_loss: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + failure_reason = f"non_finite_train_loss step:{step + 1} train_loss:{last_train_loss}" + log0(f"failed:{failure_reason}") + zero_grad_all() + break + + if optimizer_muon is not None: + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + if ema_state is not None and step >= args.ema_start_step: + with torch.no_grad(): + for name, tensor in base_model.state_dict().items(): + if name in ema_state: + ema_state[name].mul_(args.ema_decay).add_(tensor.detach().float(), alpha=1.0 - args.ema_decay) + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + if failure_reason is not None: + if master_process: + summary = { + "run_id": args.run_id, + "status": "failed", + "failure_reason": failure_reason, + "last_train_loss": last_train_loss, + "model_family": args.model_family, + "vocab_size": args.vocab_size, + "num_layers": args.num_layers, + "model_dim": args.model_dim, + "num_heads": args.num_heads, + "num_kv_heads": args.num_kv_heads, + "mlp_mult": args.mlp_mult, + "mlp_hidden_dim": resolved_mlp_hidden_dim, + "bigram_vocab_size": args.bigram_vocab_size, + "bigram_dim": args.bigram_dim if args.bigram_vocab_size > 0 else None, + "bigram_scale_init": args.bigram_scale_init if args.bigram_vocab_size > 0 else None, + "model_parameter_count": n_params, + "mamba_layer_schedule": resolved_mamba_layer_schedule, + "mamba_d_state": args.mamba_d_state if args.model_family == "mamba_hybrid" else None, + "mamba_expand": args.mamba_expand if args.model_family == "mamba_hybrid" else None, + "mamba_ngroups": args.mamba_ngroups if args.model_family == "mamba_hybrid" else None, + "mamba_chunk_size": args.mamba_chunk_size if args.model_family == "mamba_hybrid" else None, + "mamba_is_mimo": args.mamba_is_mimo if args.model_family == "mamba_hybrid" else None, + "mamba_mimo_rank": args.mamba_mimo_rank if args.model_family == "mamba_hybrid" else None, + "mamba_include_mlp": args.mamba_include_mlp if args.model_family == "mamba_hybrid" else None, + "mamba_residual_scale_init": ( + args.mamba_residual_scale_init if args.model_family == "mamba_hybrid" else None + ), + "p20_runtime_backend": args.p20_runtime_backend if args.model_family == "p20_hybrid" else None, + "p20_layer_schedule": ( + normalize_p20_layer_schedule( + args.p20_layer_schedule or default_p20_layer_schedule(args.num_layers), + args.num_layers, + ) + if args.model_family == "p20_hybrid" + else None + ), + "p20_state_blocks": resolved_p20_state_blocks, + "p20_state_blocks_requested": args.p20_state_blocks if args.model_family == "p20_hybrid" else None, + "p20_block_pair_width_cap": args.p20_block_pair_width_cap if args.model_family == "p20_hybrid" else None, + "p20_scan_chunk_size": args.p20_scan_chunk_size if args.model_family == "p20_hybrid" else None, + "p20_adapter_dim": args.p20_adapter_dim if args.model_family == "p20_hybrid" else None, + "p20_adapter_scale_init": ( + args.p20_adapter_scale_init if args.model_family == "p20_hybrid" else None + ), + "p20_loop_repeats": args.p20_loop_repeats if args.model_family == "p20_hybrid" else None, + "p20_primitive_loop_repeats": ( + args.p20_primitive_loop_repeats + if args.model_family == "p20_hybrid" + else None + ), + "p20_primitive_loop_delta_scale": ( + args.p20_primitive_loop_delta_scale + if args.model_family == "p20_hybrid" + else None + ), + "baseline_loop_repeats": args.baseline_loop_repeats if args.model_family == "baseline" else None, + "baseline_loop_layer_indices": ( + sorted(baseline_loop_layer_indices) + if args.model_family == "baseline" + else None + ), + "recurrence_encoder_schedule": ( + list(resolved_recurrence_encoder_schedule) + if resolved_recurrence_encoder_schedule is not None + else None + ), + "recurrence_decoder_schedule": ( + list(resolved_recurrence_decoder_schedule) + if resolved_recurrence_decoder_schedule is not None + else None + ), + "recurrence_start_fraction": ( + args.recurrence_start_fraction if args.model_family == "recurrent_hybrid" else None + ), + "recurrence_start_step": ( + args.recurrence_start_step if args.model_family == "recurrent_hybrid" else None + ), + "recurrence_aux_kind": resolved_recurrence_aux_kind, + "recurrence_aux_mode": resolved_recurrence_aux_mode, + "recurrence_aux_applications": ( + sorted(resolved_recurrence_aux_applications) + if resolved_recurrence_aux_applications is not None + else None + ), + "recurrence_aux_scale_init": ( + args.recurrence_aux_scale_init if args.model_family == "recurrent_hybrid" else None + ), + "recurrence_activated_step": recurrence_activated_step, + "recurrence_activated_train_time_ms": recurrence_activated_train_time_ms, + "loop_delta_scale": args.loop_delta_scale, + "optimizer_family": optimizer_family, + "embed_lr": args.embed_lr, + "head_lr": args.head_lr, + "tied_embed_lr": args.tied_embed_lr, + "matrix_lr": args.matrix_lr, + "scalar_lr": args.scalar_lr, + "p20_matrix_lr": p20_matrix_lr if args.model_family == "p20_hybrid" else None, + "p20_scalar_lr": p20_scalar_lr if args.model_family == "p20_hybrid" else None, + "non_p20_matrix_parameter_count": sum(param.numel() for param in non_p20_matrix_params), + "p20_matrix_parameter_count": sum(param.numel() for param in p20_matrix_params), + "non_p20_scalar_parameter_count": sum(param.numel() for param in non_p20_scalar_params), + "p20_scalar_parameter_count": sum(param.numel() for param in p20_scalar_params), + "beta1": args.beta1, + "beta2": args.beta2, + "compiled_model": use_compiled_model, + "compiled_muon_backend": use_compiled_muon_backend, + "final_step": step, + "iterations": args.iterations, + "world_size": world_size, + "grad_accum_steps": grad_accum_steps, + "train_batch_tokens": args.train_batch_tokens, + "train_shards": actual_train_files, + "train_tokens": actual_train_tokens, + "min_train_shards": args.min_train_shards, + "min_train_tokens": args.min_train_tokens, + "train_seq_len": args.train_seq_len, + "val_batch_size": args.val_batch_size, + "val_max_batches": args.val_max_batches if args.val_max_batches > 0 else None, + "eval_stride": args.eval_stride if args.eval_stride > 0 else None, + "sw_eval_batch": args.sw_eval_batch, + "train_time_ms": training_time_ms, + "compile_prewarm_steps_requested": args.compile_prewarm_steps, + "compile_prewarm_steps_executed": compile_prewarm_steps_executed, + "compile_prewarm_time_ms": compile_prewarm_time_ms, + "compile_prewarm_counts_toward_training": args.compile_prewarm_counts_toward_training, + "compile_prewarm_data": "synthetic" if compile_prewarm_steps_executed else None, + "compile_prewarm_state_restored": bool(compile_prewarm_steps_executed), + "warmup_steps_executed": warmup_steps_executed, + "warmup_time_ms": warmup_time_ms, + "ema_enabled": args.ema_enabled, + "ema_decay": args.ema_decay, + "ema_start_step": args.ema_start_step, + "ema_applied": ema_applied, + "ema_eval_time_ms": ema_eval_time_ms, + "quant_format": args.quant_format, + "quant_compression": args.quant_compression, + "mixed_int6_categories": sorted(args.mixed_int6_categories), + "mixed_int6_clip_quantiles": list(args.mixed_int6_clip_quantiles), + } + write_run_summary(summary_path, summary) + log0(f"run_summary_path:{summary_path}") + log0("run_summary_json:" + json.dumps(summary, sort_keys=True)) + if distributed: + dist.destroy_process_group() + return + + if ema_state is not None: + log0(f"ema:applying decay:{args.ema_decay} tracked_tensors:{len(ema_state)}") + current_state = base_model.state_dict() + averaged_state = dict(current_state) + for name, tensor in ema_state.items(): + averaged_state[name] = tensor.to(dtype=current_state[name].dtype) + base_model.load_state_dict(averaged_state, strict=True) + ema_applied = True + torch.cuda.synchronize() + ema_eval_t0 = time.perf_counter() + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + ema_eval_time_ms = 1000.0 * (time.perf_counter() - ema_eval_t0) + log0(f"ema_eval val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} eval_time:{ema_eval_time_ms:.0f}ms") + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + code_bytes, code_file_entries = collect_submission_code_bytes(Path(__file__)) + raw_model_path = out_dir / f"{args.run_id}_model.pt" + if master_process: + torch.save(base_model.state_dict(), raw_model_path) + model_bytes = raw_model_path.stat().st_size + log0(f"Serialized model: {raw_model_path} bytes:{model_bytes}") + log0(f"Code size: {code_bytes} bytes") + log0("Code files: " + json.dumps(code_file_entries, sort_keys=True)) + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + p20_ttt_pre_val_loss: float | None = None + p20_ttt_pre_val_bpb: float | None = None + p20_ttt_pre_eval_ms: float | None = None + p20_ttt_pre_stats: dict[str, object] | None = None + p20_ttt_pre_paired_stats: dict[str, object] | None = None + p20_doc_pre_val_loss: float | None = None + p20_doc_pre_val_bpb: float | None = None + p20_doc_pre_eval_ms: float | None = None + p20_doc_pre_stats: dict[str, object] | None = None + p20_ttt_post_val_loss: float | None = None + p20_ttt_post_val_bpb: float | None = None + p20_ttt_post_eval_ms: float | None = None + p20_ttt_post_stats: dict[str, object] | None = None + p20_ttt_post_paired_stats: dict[str, object] | None = None + p20_doc_post_val_loss: float | None = None + p20_doc_post_val_bpb: float | None = None + p20_doc_post_eval_ms: float | None = None + p20_doc_post_stats: dict[str, object] | None = None + p20_ttt_per_doc_base_path = Path(args.p20_ttt_per_doc_path) if args.p20_ttt_per_doc_path else None + if resolved_p20_ttt_mode != "off": + if distributed: + dist.barrier() + torch.cuda.synchronize() + t_p20_ttt = time.perf_counter() + pre_pair = eval_val_p20_gate_ttt_paired( + args, + base_model, + rank, + world_size, + device, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + bos_token_id=int(sp.bos_id()), + bootstrap_samples=args.p20_ttt_bootstrap_samples, + bootstrap_seed=args.p20_ttt_bootstrap_seed, + ) + torch.cuda.synchronize() + p20_ttt_pre_eval_ms = 1000.0 * (time.perf_counter() - t_p20_ttt) + p20_doc_pre_eval_ms = p20_ttt_pre_eval_ms + p20_doc_pre_val_loss = pre_pair.baseline_val_loss + p20_doc_pre_val_bpb = pre_pair.baseline_val_bpb + p20_doc_pre_stats = pre_pair.baseline_stats.__dict__ + p20_ttt_pre_val_loss = pre_pair.ttt_val_loss + p20_ttt_pre_val_bpb = pre_pair.ttt_val_bpb + p20_ttt_pre_stats = pre_pair.ttt_stats.__dict__ + p20_ttt_pre_paired_stats = pre_pair.paired_stats + if master_process and p20_ttt_per_doc_base_path is not None: + suffix = p20_ttt_per_doc_base_path.suffix or ".jsonl" + pre_path = p20_ttt_per_doc_base_path.with_name( + f"{p20_ttt_per_doc_base_path.stem}_pre_quant{suffix}" + ) + write_jsonl(pre_path, [{"quant_stage": "pre_quant", **record} for record in pre_pair.per_doc]) + log0(f"p20_ttt_pre_quant_per_doc_path:{pre_path}") + log0( + f"final_pre_quant_p20_doc_chunks val_loss:{p20_doc_pre_val_loss:.4f} " + f"val_bpb:{p20_doc_pre_val_bpb:.4f} eval_time:{p20_doc_pre_eval_ms:.0f}ms " + f"docs:{pre_pair.baseline_stats.doc_count} chunks:{pre_pair.baseline_stats.chunk_count}" + ) + log0( + f"final_pre_quant_p20_doc_chunks_exact " + f"val_loss:{p20_doc_pre_val_loss:.8f} val_bpb:{p20_doc_pre_val_bpb:.8f}" + ) + log0( + f"final_pre_quant_p20_ttt mode:{resolved_p20_ttt_mode} " + f"val_loss:{p20_ttt_pre_val_loss:.4f} val_bpb:{p20_ttt_pre_val_bpb:.4f} " + f"eval_time:{p20_ttt_pre_eval_ms:.0f}ms docs:{pre_pair.ttt_stats.doc_count} " + f"chunks:{pre_pair.ttt_stats.chunk_count} updates:{pre_pair.ttt_stats.update_count} " + f"delta_bpb:{pre_pair.paired_stats['delta_bpb']:.6f}" + ) + log0( + f"final_pre_quant_p20_ttt_exact mode:{resolved_p20_ttt_mode} " + f"val_loss:{p20_ttt_pre_val_loss:.8f} val_bpb:{p20_ttt_pre_val_bpb:.8f}" + ) + log0("final_pre_quant_p20_ttt_paired_stats:" + json.dumps(pre_pair.paired_stats, sort_keys=True)) + + if args.quant_format in {"int8", "int8_zlib"}: + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_extension = "int8" + quant_payload_bytes_key = "int8_payload_bytes" + elif args.quant_format in {"mixed_int6_clipsearch", "int6_clipsearch", "gptq_lite"}: + quant_obj, quant_stats = quantize_state_dict_mixed_int6_clipsearch( + base_model.state_dict(), + int6_categories=args.mixed_int6_categories, + clip_quantiles=args.mixed_int6_clip_quantiles, + ) + quant_extension = "int6clip" + quant_payload_bytes_key = "quantized_payload_bytes" + else: + raise ValueError( + "QUANT_FORMAT must be int8_zlib|int8|mixed_int6_clipsearch|int6_clipsearch|gptq_lite, " + f"got {args.quant_format!r}" + ) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_quant_payload(quant_raw, args.quant_compression) + quant_raw_bytes = len(quant_raw) + quant_path = out_dir / f"{args.run_id}_model.{quant_extension}.{args.quant_compression}.ptz" + if master_process: + with open(quant_path, "wb") as f: + f.write(quant_blob) + quant_file_bytes = quant_path.stat().st_size + payload_bytes = int(quant_stats.get(quant_payload_bytes_key, 0)) + ratio = quant_stats["baseline_tensor_bytes"] / max(payload_bytes, 1) + log0( + f"Serialized model {args.quant_format}+{args.quant_compression}: {quant_path} " + f"bytes:{quant_file_bytes} (payload:{payload_bytes} raw_torch:{quant_raw_bytes} " + f"payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_format}+{args.quant_compression}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_path, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_quant_payload(quant_blob_disk, args.quant_compression)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_export(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + q_eval_ms = 1000.0 * (time.perf_counter() - t_qeval) + log0( + f"final_quant_roundtrip format:{args.quant_format}+{args.quant_compression} " + f"val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{q_eval_ms:.0f}ms" + ) + log0( + f"final_quant_roundtrip_exact format:{args.quant_format}+{args.quant_compression} " + f"val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}" + ) + + if resolved_p20_ttt_mode != "off": + if distributed: + dist.barrier() + torch.cuda.synchronize() + t_p20_ttt = time.perf_counter() + post_pair = eval_val_p20_gate_ttt_paired( + args, + base_model, + rank, + world_size, + device, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + bos_token_id=int(sp.bos_id()), + bootstrap_samples=args.p20_ttt_bootstrap_samples, + bootstrap_seed=args.p20_ttt_bootstrap_seed, + ) + torch.cuda.synchronize() + p20_ttt_post_eval_ms = 1000.0 * (time.perf_counter() - t_p20_ttt) + p20_doc_post_eval_ms = p20_ttt_post_eval_ms + p20_doc_post_val_loss = post_pair.baseline_val_loss + p20_doc_post_val_bpb = post_pair.baseline_val_bpb + p20_doc_post_stats = post_pair.baseline_stats.__dict__ + p20_ttt_post_val_loss = post_pair.ttt_val_loss + p20_ttt_post_val_bpb = post_pair.ttt_val_bpb + p20_ttt_post_stats = post_pair.ttt_stats.__dict__ + p20_ttt_post_paired_stats = post_pair.paired_stats + if master_process and p20_ttt_per_doc_base_path is not None: + suffix = p20_ttt_per_doc_base_path.suffix or ".jsonl" + post_path = p20_ttt_per_doc_base_path.with_name( + f"{p20_ttt_per_doc_base_path.stem}_post_quant{suffix}" + ) + write_jsonl(post_path, [{"quant_stage": "post_quant", **record} for record in post_pair.per_doc]) + log0(f"p20_ttt_post_quant_per_doc_path:{post_path}") + log0( + f"final_post_quant_p20_doc_chunks val_loss:{p20_doc_post_val_loss:.4f} " + f"val_bpb:{p20_doc_post_val_bpb:.4f} eval_time:{p20_doc_post_eval_ms:.0f}ms " + f"docs:{post_pair.baseline_stats.doc_count} chunks:{post_pair.baseline_stats.chunk_count}" + ) + log0( + f"final_post_quant_p20_doc_chunks_exact " + f"val_loss:{p20_doc_post_val_loss:.8f} val_bpb:{p20_doc_post_val_bpb:.8f}" + ) + log0( + f"final_post_quant_p20_ttt mode:{resolved_p20_ttt_mode} " + f"val_loss:{p20_ttt_post_val_loss:.4f} val_bpb:{p20_ttt_post_val_bpb:.4f} " + f"eval_time:{p20_ttt_post_eval_ms:.0f}ms docs:{post_pair.ttt_stats.doc_count} " + f"chunks:{post_pair.ttt_stats.chunk_count} updates:{post_pair.ttt_stats.update_count} " + f"delta_bpb:{post_pair.paired_stats['delta_bpb']:.6f}" + ) + log0( + f"final_post_quant_p20_ttt_exact mode:{resolved_p20_ttt_mode} " + f"val_loss:{p20_ttt_post_val_loss:.8f} val_bpb:{p20_ttt_post_val_bpb:.8f}" + ) + log0("final_post_quant_p20_ttt_paired_stats:" + json.dumps(post_pair.paired_stats, sort_keys=True)) + + sw_pre_val_loss: float | None = None + sw_pre_val_bpb: float | None = None + sw_pre_eval_ms: float | None = None + sw_post_val_loss: float | None = None + sw_post_val_bpb: float | None = None + sw_post_eval_ms: float | None = None + if args.eval_stride > 0: + if distributed: + dist.barrier() + base_model.load_state_dict(torch.load(raw_model_path, map_location="cpu"), strict=True) + restore_low_dim_params_to_fp32(base_model) + torch.cuda.synchronize() + t_sw = time.perf_counter() + sw_pre_val_loss, sw_pre_val_bpb = eval_val_sliding_window( + args, + base_model, + rank, + world_size, + device, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + sw_pre_eval_ms = 1000.0 * (time.perf_counter() - t_sw) + log0( + f"final_pre_quant_sliding_window stride:{args.eval_stride} " + f"val_loss:{sw_pre_val_loss:.4f} val_bpb:{sw_pre_val_bpb:.4f} eval_time:{sw_pre_eval_ms:.0f}ms" + ) + log0( + f"final_pre_quant_sliding_window_exact stride:{args.eval_stride} " + f"val_loss:{sw_pre_val_loss:.8f} val_bpb:{sw_pre_val_bpb:.8f}" + ) + + base_model.load_state_dict(dequantize_state_dict_export(quant_state), strict=True) + restore_low_dim_params_to_fp32(base_model) + torch.cuda.synchronize() + t_swq = time.perf_counter() + sw_post_val_loss, sw_post_val_bpb = eval_val_sliding_window( + args, + base_model, + rank, + world_size, + device, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + sw_post_eval_ms = 1000.0 * (time.perf_counter() - t_swq) + log0( + f"final_post_quant_sliding_window stride:{args.eval_stride} " + f"val_loss:{sw_post_val_loss:.4f} val_bpb:{sw_post_val_bpb:.4f} eval_time:{sw_post_eval_ms:.0f}ms" + ) + log0( + f"final_post_quant_sliding_window_exact stride:{args.eval_stride} " + f"val_loss:{sw_post_val_loss:.8f} val_bpb:{sw_post_val_bpb:.8f}" + ) + + if master_process: + raw_model_bytes = raw_model_path.stat().st_size + quant_file_bytes = quant_path.stat().st_size + summary = { + "run_id": args.run_id, + "model_family": args.model_family, + "vocab_size": args.vocab_size, + "num_layers": args.num_layers, + "model_dim": args.model_dim, + "num_heads": args.num_heads, + "num_kv_heads": args.num_kv_heads, + "mlp_mult": args.mlp_mult, + "mlp_hidden_dim": resolved_mlp_hidden_dim, + "bigram_vocab_size": args.bigram_vocab_size, + "bigram_dim": args.bigram_dim if args.bigram_vocab_size > 0 else None, + "bigram_scale_init": args.bigram_scale_init if args.bigram_vocab_size > 0 else None, + "model_parameter_count": n_params, + "mamba_layer_schedule": resolved_mamba_layer_schedule, + "mamba_d_state": args.mamba_d_state if args.model_family == "mamba_hybrid" else None, + "mamba_expand": args.mamba_expand if args.model_family == "mamba_hybrid" else None, + "mamba_ngroups": args.mamba_ngroups if args.model_family == "mamba_hybrid" else None, + "mamba_chunk_size": args.mamba_chunk_size if args.model_family == "mamba_hybrid" else None, + "mamba_is_mimo": args.mamba_is_mimo if args.model_family == "mamba_hybrid" else None, + "mamba_mimo_rank": args.mamba_mimo_rank if args.model_family == "mamba_hybrid" else None, + "mamba_include_mlp": args.mamba_include_mlp if args.model_family == "mamba_hybrid" else None, + "mamba_residual_scale_init": ( + args.mamba_residual_scale_init if args.model_family == "mamba_hybrid" else None + ), + "p20_runtime_backend": args.p20_runtime_backend if args.model_family == "p20_hybrid" else None, + "p20_layer_schedule": resolved_p20_layer_schedule, + "p20_state_blocks": resolved_p20_state_blocks, + "p20_state_blocks_requested": args.p20_state_blocks if args.model_family == "p20_hybrid" else None, + "p20_block_pair_width_cap": args.p20_block_pair_width_cap if args.model_family == "p20_hybrid" else None, + "p20_scan_chunk_size": args.p20_scan_chunk_size if args.model_family == "p20_hybrid" else None, + "p20_adapter_dim": args.p20_adapter_dim if args.model_family == "p20_hybrid" else None, + "p20_adapter_scale_init": ( + args.p20_adapter_scale_init if args.model_family == "p20_hybrid" else None + ), + "p20_loop_repeats": args.p20_loop_repeats if args.model_family == "p20_hybrid" else None, + "p20_primitive_loop_repeats": ( + args.p20_primitive_loop_repeats + if args.model_family == "p20_hybrid" + else None + ), + "p20_primitive_loop_repeats_by_p": ( + list(resolved_p20_primitive_loop_repeats_by_p) + if resolved_p20_primitive_loop_repeats_by_p is not None + else None + ), + "p20_primitive_loop_delta_scale": ( + args.p20_primitive_loop_delta_scale + if args.model_family == "p20_hybrid" + else None + ), + "baseline_loop_repeats": args.baseline_loop_repeats if args.model_family == "baseline" else None, + "baseline_loop_layer_indices": ( + sorted(baseline_loop_layer_indices) + if args.model_family == "baseline" + else None + ), + "recurrence_encoder_schedule": ( + list(resolved_recurrence_encoder_schedule) + if resolved_recurrence_encoder_schedule is not None + else None + ), + "recurrence_decoder_schedule": ( + list(resolved_recurrence_decoder_schedule) + if resolved_recurrence_decoder_schedule is not None + else None + ), + "recurrence_start_fraction": ( + args.recurrence_start_fraction if args.model_family == "recurrent_hybrid" else None + ), + "recurrence_start_step": ( + args.recurrence_start_step if args.model_family == "recurrent_hybrid" else None + ), + "recurrence_aux_kind": resolved_recurrence_aux_kind, + "recurrence_aux_mode": resolved_recurrence_aux_mode, + "recurrence_aux_applications": ( + sorted(resolved_recurrence_aux_applications) + if resolved_recurrence_aux_applications is not None + else None + ), + "recurrence_aux_scale_init": ( + args.recurrence_aux_scale_init if args.model_family == "recurrent_hybrid" else None + ), + "recurrence_activated_step": recurrence_activated_step, + "recurrence_activated_train_time_ms": recurrence_activated_train_time_ms, + "loop_delta_scale": args.loop_delta_scale, + "optimizer_family": optimizer_family, + "embed_lr": args.embed_lr, + "head_lr": args.head_lr, + "tied_embed_lr": args.tied_embed_lr, + "matrix_lr": args.matrix_lr, + "scalar_lr": args.scalar_lr, + "p20_matrix_lr": p20_matrix_lr if args.model_family == "p20_hybrid" else None, + "p20_scalar_lr": p20_scalar_lr if args.model_family == "p20_hybrid" else None, + "non_p20_matrix_parameter_count": sum(param.numel() for param in non_p20_matrix_params), + "p20_matrix_parameter_count": sum(param.numel() for param in p20_matrix_params), + "non_p20_scalar_parameter_count": sum(param.numel() for param in non_p20_scalar_params), + "p20_scalar_parameter_count": sum(param.numel() for param in p20_scalar_params), + "beta1": args.beta1, + "beta2": args.beta2, + "compiled_model": use_compiled_model, + "compiled_muon_backend": use_compiled_muon_backend, + "final_step": step, + "iterations": args.iterations, + "world_size": world_size, + "grad_accum_steps": grad_accum_steps, + "stopped_early": bool(stop_after_step is not None and step < args.iterations), + "train_batch_tokens": args.train_batch_tokens, + "train_shards": actual_train_files, + "train_tokens": actual_train_tokens, + "min_train_shards": args.min_train_shards, + "min_train_tokens": args.min_train_tokens, + "train_seq_len": args.train_seq_len, + "val_batch_size": args.val_batch_size, + "val_max_batches": args.val_max_batches if args.val_max_batches > 0 else None, + "eval_stride": args.eval_stride if args.eval_stride > 0 else None, + "sw_eval_batch": args.sw_eval_batch, + "train_time_ms": training_time_ms, + "compile_prewarm_steps_requested": args.compile_prewarm_steps, + "compile_prewarm_steps_executed": compile_prewarm_steps_executed, + "compile_prewarm_time_ms": compile_prewarm_time_ms, + "compile_prewarm_counts_toward_training": args.compile_prewarm_counts_toward_training, + "compile_prewarm_data": "synthetic" if compile_prewarm_steps_executed else None, + "compile_prewarm_state_restored": bool(compile_prewarm_steps_executed), + "warmup_steps_executed": warmup_steps_executed, + "warmup_time_ms": warmup_time_ms, + "ema_enabled": args.ema_enabled, + "ema_decay": args.ema_decay, + "ema_start_step": args.ema_start_step, + "ema_applied": ema_applied, + "ema_eval_time_ms": ema_eval_time_ms, + "quant_format": args.quant_format, + "quant_compression": args.quant_compression, + "mixed_int6_categories": sorted(args.mixed_int6_categories), + "mixed_int6_clip_quantiles": list(args.mixed_int6_clip_quantiles), + "quant_raw_bytes": quant_raw_bytes, + "quant_file_bytes": quant_file_bytes, + "quant_stats": quant_stats, + "code_bytes": code_bytes, + "code_files": code_file_entries, + "raw_model_bytes": raw_model_bytes, + "int8_zlib_bytes": quant_file_bytes, + "total_submission_bytes_raw": raw_model_bytes + code_bytes, + "total_submission_bytes_quantized": quant_file_bytes + code_bytes, + "total_submission_bytes_int8_zlib": quant_file_bytes + code_bytes, + "pre_quant_val_loss": val_loss, + "pre_quant_val_bpb": val_bpb, + "pre_quant_sliding_eval_time_ms": sw_pre_eval_ms, + "pre_quant_sliding_val_loss": sw_pre_val_loss, + "pre_quant_sliding_val_bpb": sw_pre_val_bpb, + "post_quant_eval_time_ms": q_eval_ms, + "post_quant_val_loss": q_val_loss, + "post_quant_val_bpb": q_val_bpb, + "p20_ttt_mode": resolved_p20_ttt_mode, + "p20_ttt_chunk_size": args.p20_ttt_chunk_size, + "p20_ttt_context_size": args.p20_ttt_context_size if args.p20_ttt_context_size > 0 else args.train_seq_len, + "p20_ttt_lr": args.p20_ttt_lr, + "p20_ttt_beta1": args.p20_ttt_beta1, + "p20_ttt_beta2": args.p20_ttt_beta2, + "p20_ttt_weight_decay": args.p20_ttt_weight_decay, + "p20_ttt_grad_clip_norm": args.p20_ttt_grad_clip_norm, + "p20_ttt_max_docs": args.p20_ttt_max_docs if args.p20_ttt_max_docs > 0 else None, + "p20_ttt_bootstrap_samples": args.p20_ttt_bootstrap_samples, + "p20_ttt_bootstrap_seed": args.p20_ttt_bootstrap_seed, + "p20_ttt_per_doc_path": args.p20_ttt_per_doc_path or None, + "pre_quant_p20_doc_chunks_eval_time_ms": p20_doc_pre_eval_ms, + "pre_quant_p20_doc_chunks_val_loss": p20_doc_pre_val_loss, + "pre_quant_p20_doc_chunks_val_bpb": p20_doc_pre_val_bpb, + "pre_quant_p20_doc_chunks_stats": p20_doc_pre_stats, + "pre_quant_p20_ttt_eval_time_ms": p20_ttt_pre_eval_ms, + "pre_quant_p20_ttt_val_loss": p20_ttt_pre_val_loss, + "pre_quant_p20_ttt_val_bpb": p20_ttt_pre_val_bpb, + "pre_quant_p20_ttt_stats": p20_ttt_pre_stats, + "pre_quant_p20_ttt_paired_stats": p20_ttt_pre_paired_stats, + "post_quant_p20_doc_chunks_eval_time_ms": p20_doc_post_eval_ms, + "post_quant_p20_doc_chunks_val_loss": p20_doc_post_val_loss, + "post_quant_p20_doc_chunks_val_bpb": p20_doc_post_val_bpb, + "post_quant_p20_doc_chunks_stats": p20_doc_post_stats, + "post_quant_p20_ttt_eval_time_ms": p20_ttt_post_eval_ms, + "post_quant_p20_ttt_val_loss": p20_ttt_post_val_loss, + "post_quant_p20_ttt_val_bpb": p20_ttt_post_val_bpb, + "post_quant_p20_ttt_stats": p20_ttt_post_stats, + "post_quant_p20_ttt_paired_stats": p20_ttt_post_paired_stats, + "post_quant_sliding_eval_time_ms": sw_post_eval_ms, + "post_quant_sliding_val_loss": sw_post_val_loss, + "post_quant_sliding_val_bpb": sw_post_val_bpb, + } + write_run_summary(summary_path, summary) + log0(f"run_summary_path:{summary_path}") + log0("run_summary_json:" + json.dumps(summary, sort_keys=True)) + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main()