@@ -0,0 +1,62 @@
# Cautious Muon + SP4096 + Depth Recurrence + Parallel Residuals

**val_bpb = 1.1604** (3-seed mean, std = 0.0033)

**Copilot AI (Apr 5, 2026), on lines +1 to +4:** This submission is labeled as a "Record" (folder/PR title), but the reported mean val_bpb = 1.1604 is substantially worse than the current 10min_16mb leaderboard entries (e.g. 1.1228 in the repo README). Consider renaming the PR/folder/README to avoid implying it's a new SOTA record if it's intended as a non-record/ablation submission.

## Results

| Seed | val_bpb | val_loss | Artifact Size |
|------|---------|----------|---------------|
| 42 | 1.1568 | 2.6619 | 15,179,504 B |
| 314 | 1.1611 | 2.6717 | 15,173,470 B |
| 999 | 1.1634 | 2.6770 | 15,159,223 B |
| **Mean** | **1.1604** | **2.6702** | **15,170,732 B** |

**Copilot AI (Apr 5, 2026), on lines +9 to +13:** The results table rows start with double pipes (`|| ...`), which GitHub Markdown renders as an extra empty first column. Use a single leading `|` for each row so the table displays correctly.
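
As a quick consistency check (not part of the submission tooling), the reported val_loss (nats/token) and val_bpb pairs imply roughly 3.32 bytes per SP4096 token on the validation set. A minimal sketch of the arithmetic, assuming the usual bits-per-byte definition:

```python
import math

# val_bpb = val_loss / (ln(2) * bytes_per_token)  -> solve for bytes_per_token
val_loss, val_bpb = 2.66190192, 1.15683191        # seed 42 row above
bytes_per_token = val_loss / (math.log(2) * val_bpb)
print(f"{bytes_per_token:.3f} bytes/token")       # ~3.320; seeds 314 and 999 give the same ratio
```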

## Key Technique: Cautious Muon (arXiv:2411.16085)

The primary modification is applying the Cautious optimizer principle to the Muon optimizer. After Newton-Schulz orthogonalization and MuonEq-R row normalization, the update is masked to only apply where the orthogonalized direction agrees with the raw gradient sign:

```python
caution_mask = (g * raw_grad > 0).to(g.dtype)               # keep entries agreeing in sign with the raw gradient
g = g * caution_mask / caution_mask.mean().clamp_min(1e-3)  # rescale by the kept fraction so the update is not shrunk
```

**Copilot AI (Apr 5, 2026):** The description of Cautious Muon says the mask is applied "after Newton-Schulz orthogonalization and MuonEq-R row normalization", but the stack list later claims MuonEq-R happens before Newton-Schulz. Please make the ordering consistent with the actual implementation to avoid confusion when reproducing.

This filters out "stale" momentum directions that disagree with the current gradient, providing ~1.47x effective convergence per step with zero parameter overhead and no impact on artifact size.
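
For context, here is a minimal, self-contained sketch of where those two lines sit in a Muon-style step. Only the `caution_mask` lines are taken from this PR; the function names are illustrative, the Newton-Schulz iteration is the standard Muon one, and the momentum handling, MuonEq-R normalization, and weight decay are assumptions, since the optimizer code itself is not shown in this diff.

```python
import torch

def newton_schulz_orth(G: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    # Quintic Newton-Schulz iteration used by Muon to approximately orthogonalize G.
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (G.norm() + eps)
    transposed = X.size(0) > X.size(1)
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if transposed else X

@torch.no_grad()
def cautious_muon_step(param, raw_grad, momentum_buf, lr=0.02, momentum=0.99, ns_steps=5):
    # Heavy-ball momentum on the raw gradient (the PR may blend differently).
    momentum_buf.mul_(momentum).add_(raw_grad)
    g = newton_schulz_orth(momentum_buf, steps=ns_steps)
    # MuonEq-R row normalization would go before or after the line above,
    # depending on how the ordering question raised in review is resolved.

    # The two lines quoted in the README:
    caution_mask = (g * raw_grad > 0).to(g.dtype)
    g = g * caution_mask / caution_mask.mean().clamp_min(1e-3)

    param.add_(g, alpha=-lr)
```

The defaults above (`lr=0.02`, `momentum=0.99`, `ns_steps=5`) are the logged `matrix_lr`, `muon_momentum`, and `muon_backend_steps` values.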

## Full Architecture Stack

Built on PR #1334 (aryanbhosale) with:
- **SP4096 BPE tokenizer** (from PR #1218, @clarkkev)
- **Depth recurrence** layers 4,5 (13 virtual layers from 11 physical, activated at step 3000; see the sketch after this list)
- **Parallel residuals** from layer 7 (separate attn/MLP lanes with learnable merge)
- **MuonEq-R** row normalization before Newton-Schulz (arXiv:2603.28254)

**Copilot AI (Apr 5, 2026):** This bullet claims "MuonEq-R row normalization before Newton-Schulz", which contradicts the earlier description of the masking point. Please align the ordering here with the actual optimizer pipeline (and with the earlier section) so readers don't implement the wrong sequence. Suggested change: "row normalization before Newton-Schulz" -> "row normalization after Newton-Schulz orthogonalization".

- **QK-Gain 5.0** per-head query-key scaling
- **EMA 0.997** weight averaging
- **Full GPTQ INT6** quantization with selective ±1 pruning
- **Brotli compression**
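
The depth-recurrence bullet maps onto the `virtual_layers` order printed in the training logs. Below is a minimal sketch of that schedule; `virtual_layer_schedule` and `forward_with_recurrence` are illustrative names and the block/gating structure is an assumption, only the resulting order comes from the logs.

```python
def virtual_layer_schedule(num_layers=11, recur_layers=(4, 5), recurrence_active=True):
    # Physical block indices in application order. With recurrence active this
    # reproduces the order from the logs:
    # [0, 1, 2, 3, 4, 5, 4, 5, 6, 7, 8, 9, 10]  (13 virtual layers from 11 physical)
    order = []
    for i in range(num_layers):
        order.append(i)
        if recurrence_active and i == max(recur_layers):
            order.extend(recur_layers)   # run the recurrent pair a second time
    return order

def forward_with_recurrence(blocks, x, step, recur_start_step=3000):
    # `blocks` are the physical transformer blocks; the recurrent layers share
    # weights across both passes, so no extra parameters are stored.
    active = step >= recur_start_step
    for idx in virtual_layer_schedule(len(blocks), recurrence_active=active):
        x = blocks[idx](x)
    return x
```

Because layers 4 and 5 are reused rather than duplicated, the artifact size is unchanged; only compute per token grows, which is visible as the tok/s drop after step 3000 in the logs.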

## Non-matrix parameters

Token embeddings, scalar parameters, and the output head use standard `torch.optim.AdamW`. Cautious masking is applied only inside Muon, which optimizes the matrix parameters.
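
A minimal sketch of that split, assuming a name-based grouping (the actual selection logic is not in this diff, and tied embeddings with their separate `tied_embed_lr` are ignored here); the learning rates and weight decays are the logged hyperparameters:

```python
import torch

def build_param_groups(model: torch.nn.Module):
    # Hidden matrices (ndim >= 2, excluding embeddings/head) go to Cautious Muon;
    # everything else goes to AdamW. The name checks are illustrative only.
    matrix, embed, head, scalar = [], [], [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if "embed" in name:
            embed.append(p)
        elif "head" in name:
            head.append(p)
        elif p.ndim >= 2:
            matrix.append(p)
        else:
            scalar.append(p)

    adamw = torch.optim.AdamW(
        [
            {"params": embed,  "lr": 0.6,   "weight_decay": 0.09},  # embed_lr / embed_wd
            {"params": head,   "lr": 0.008},                        # head_lr
            {"params": scalar, "lr": 0.02},                         # scalar_lr
        ],
        betas=(0.9, 0.95), eps=1e-8, weight_decay=0.02,             # beta1 / beta2 / adam_eps / adam_wd
    )
    # `matrix` is stepped with the Cautious Muon update sketched in the
    # Key Technique section (matrix_lr 0.02, muon_wd 0.09).
    return adamw, matrix
```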

## Compliance

- Track A fixed predictor -- no TTT, no SLOT, no eval-time adaptation
- All predictions are causal and normalized via softmax (F.cross_entropy); see the sketch after this list
- Artifact under 16MB limit (max 15,179,504 bytes)
- Training completes within 600s wallclock on 8xH100 SXM
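
For concreteness, a minimal sketch of the causal, softmax-normalized evaluation the second bullet refers to; `causal_val_loss` is an illustrative name, and the real harness evaluates at `eval_seq_len` 2048 with stride 64 and a sliding window, which this sketch omits:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def causal_val_loss(model, tokens):
    # tokens: (batch, seq) ids from the validation shards. Position t is predicted
    # from tokens <= t only; F.cross_entropy applies log-softmax over the 4096-token
    # vocab, so every prediction is softmax-normalized.
    logits = model(tokens[:, :-1])                 # (batch, seq-1, vocab)
    loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        tokens[:, 1:].reshape(-1),
    )
    return loss  # nats/token; val_bpb = loss / (ln(2) * bytes_per_token)
```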

## Reproduction

```bash
cd /workspace/parameter-golf
# Download SP4096 data
MATCHED_FINEWEB_REPO_ID=kevclark/parameter-golf python3 data/cached_challenge_fineweb.py --variant sp4096
# Run
SEED=42 torchrun --standalone --nproc_per_node=8 train_gpt.py
```

## Credits

- PR #1334 (@aryanbhosale) -- base architecture (SP4096, depth recurrence, parallel residuals, MuonEq-R)
- PR #1218 (@clarkkev) -- SP4096 tokenizer
- Liang et al. (arXiv:2411.16085) -- Cautious Optimizers
@@ -0,0 +1,37 @@
{
"author": "X-Abhishek-X",
"github_id": "X-Abhishek-X",
"name": "Cautious Muon + SP4096 + Depth Recurrence + Parallel Residuals",
"blurb": "Applies Cautious Muon (arXiv:2411.16085) to the Muon optimizer — masks Newton-Schulz updates where the orthogonalized direction disagrees with the raw gradient sign, providing ~1.47x effective convergence per step with zero parameter overhead. Built on PR #1334 (aryanbhosale) base with SP4096 vocabulary, depth recurrence (layers 4,5), parallel residuals (from layer 7), MuonEq-R, QK-Gain 5.0, and full GPTQ INT6 + Brotli compression. Mean val_bpb = 1.1604 (3 seeds, std = 0.0033).",
"date": "2026-04-05",

**Copilot AI (Apr 5, 2026):** The `date` field is just YYYY-MM-DD, while other submissions typically use an ISO-8601 timestamp (e.g. 2026-03-22T00:00:00Z). Using a consistent timestamp format helps automated consumers parse and sort submissions reliably. Suggested change: `"date": "2026-04-05T00:00:00Z",`

"track": "10min_16mb",
"val_loss": 2.67020395,
"val_bpb": 1.16043988,
"val_loss_std": 0.00764948,
"val_bpb_std": 0.00332438,
"seeds": [42, 314, 999],

**Copilot AI (Apr 5, 2026), on lines +8 to +12:** Most existing 10min_16mb submission.json files include a `bytes_total` field (e.g., records/track_10min_16mb/2026-03-22_11L_EMA_GPTQ-lite_warmdown3500_QAT015_1.1233/submission.json:9). This submission uses `artifact_bytes_*` instead; consider adding `bytes_total` (and optionally `bytes_code`) for consistency and easier leaderboard/tooling ingestion.

"seed_results": {
"42": {
"val_loss": 2.66190192,
"val_bpb": 1.15683191,
"artifact_bytes": 15179504
},
"314": {
"val_loss": 2.67174312,
"val_bpb": 1.16110878,
"artifact_bytes": 15173470
},
"999": {
"val_loss": 2.67696681,
"val_bpb": 1.16337894,
"artifact_bytes": 15159223
}
},
"artifact_bytes_mean": 15170732,
"artifact_bytes_max": 15179504,
"hardware": "8x H100 SXM (RunPod On-Demand)",
"pytorch_version": "2.9.1",
"cuda_version": "12.8",
"technique_summary": "Cautious Muon optimizer (arXiv:2411.16085), SP4096 BPE tokenizer, depth recurrence layers 4-5 (start step 3000), parallel residuals from layer 7, MuonEq-R row normalization, QK-Gain 5.0, EMA 0.997, full GPTQ INT6 quantization with selective pruning, Brotli compression",
"comparison_baseline_pr": 1334
}

@@ -0,0 +1,134 @@
W0405 12:33:07.793000 58424 torch/distributed/run.py:803]
W0405 12:33:07.793000 58424 torch/distributed/run.py:803] *****************************************
W0405 12:33:07.793000 58424 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0405 12:33:07.793000 58424 torch/distributed/run.py:803] *****************************************
Hyperparameters:
adam_eps: 1e-08
adam_wd: 0.02
beta1: 0.9
beta2: 0.95
cautious_muon: True
compressor: brotli
data_dir: ./data/
datasets_dir: ./data/datasets/fineweb10B_sp4096
distributed: True
ema_decay: 0.997
embed_lr: 0.6
embed_wd: 0.09
embedding_dim: 512
eval_seq_len: 2048
eval_stride: 64
gptq_calibration_batches: 64
gptq_enabled: True
gptq_reserve_seconds: 10.0
grad_accum_steps: 1
grad_clip_norm: 0.3
head_lr: 0.008
is_main_process: True
iterations: 20000
ln_scale: True
local_rank: 0
logfile: logs/523d357e-1519-45f9-bc20-69bbfb520c1b.txt
logit_softcap: 30.0
matrix_lr: 0.02
max_wallclock_seconds: 600.0
min_lr: 0.0
mlp_mult: 4.0
model_dim: 512
model_path: final_model.pt
muon_backend_steps: 5
muon_beta2: 0.95
muon_momentum: 0.99
muon_momentum_warmup_start: 0.92
muon_momentum_warmup_steps: 1500
muon_wd: 0.09
num_heads: 8
num_kv_heads: 4
num_layers: 11
parallel_start_layer: 7
qk_gain_init: 5.0
quantized_model_path: final_model.int6.ptz
rank: 0
recur_layers: 4,5
recur_start_step: 3000
rope_base: 10000.0
rope_dims: 16
rope_train_seq_len: 2048
run_id: 523d357e-1519-45f9-bc20-69bbfb520c1b
scalar_lr: 0.02
seed: 314
skip_gates_enabled: True
sliding_window_enabled: True
tie_embeddings: True
tied_embed_init_std: 0.005
tied_embed_lr: 0.03
tokenizer_path: ./data/tokenizers/fineweb_4096_bpe.model
train_batch_tokens: 786432
train_files: ./data/datasets/fineweb10B_sp4096/fineweb_train_*.bin
train_log_every: 500
train_seq_len: 2048
ttt_batch_seqs: 32
ttt_chunk_tokens: 32768
ttt_enabled: False
ttt_epochs: 3
ttt_freeze_blocks: 0
ttt_grad_clip: 1.0
ttt_lr: 0.002
ttt_momentum: 0.9
val_batch_tokens: 524288
val_files: ./data/datasets/fineweb10B_sp4096/fineweb_val_*.bin
val_loss_every: 4000
ve_dim: 128
ve_enabled: True
ve_layers: 9,10
vocab_size: 4096
warmdown_frac: 0.667
warmup_steps: 20
world_size: 8
xsa_last_n: 11
train_shards: 80
val_tokens: 45508608
model_params:34401372
gptq:reserving 10s, effective=590000ms
warmup_step: 1/20
warmup_step: 2/20
warmup_step: 3/20
warmup_step: 4/20
warmup_step: 5/20
warmup_step: 6/20
warmup_step: 10/20
warmup_step: 20/20
0/20000 val_loss: 8.3172 val_bpb: 3.6146
1/20000 train_loss: 8.3192 train_time: 0.0m tok/s: 8417806
2/20000 train_loss: 11.4353 train_time: 0.0m tok/s: 8294705
3/20000 train_loss: 8.8870 train_time: 0.0m tok/s: 8200165
4/20000 train_loss: 8.0296 train_time: 0.0m tok/s: 8159612
5/20000 train_loss: 8.6996 train_time: 0.0m tok/s: 8141617
500/20000 train_loss: 3.0633 train_time: 0.8m tok/s: 7939661
1000/20000 train_loss: 2.9869 train_time: 1.7m tok/s: 7911553
1500/20000 train_loss: 2.9710 train_time: 2.5m tok/s: 7908791
2000/20000 train_loss: 2.6991 train_time: 3.3m tok/s: 7909908
2500/20000 train_loss: 2.7380 train_time: 4.1m tok/s: 7911617
3000/20000 train_loss: 2.7805 train_time: 5.0m tok/s: 7913690
recurrence:activated at step 3000, virtual_layers=[0, 1, 2, 3, 4, 5, 4, 5, 6, 7, 8, 9, 10]
3500/20000 train_loss: 2.6930 train_time: 6.1m tok/s: 7516407
4000/20000 train_loss: 2.6237 train_time: 7.1m tok/s: 7424228
4000/20000 val_loss: 2.6465 val_bpb: 1.1501
4500/20000 train_loss: 2.5745 train_time: 8.0m tok/s: 7355829
5000/20000 train_loss: 2.5200 train_time: 9.0m tok/s: 7301867
5449/20000 val_loss: 2.5373 val_bpb: 1.1027
stopping_early: wallclock_cap train_time: 590030ms step: 5449/20000
peak memory allocated: 30120 MiB reserved: 30154 MiB
ema:applying EMA weights
pre-quantization post-ema val_loss:2.53590710 val_bpb:1.10207601 eval_time:2003ms
Serialized model: 132406149 bytes
Code size: 24659 bytes
GPTQ:collecting Hessians from calibration data...
GPTQ:collected 66 Hessians in 9.7s
GPTQ quantization: 66 layers with full GPTQ, 0 fallback to clip-search
selective_prune: unpruned=16.59MB target=16.0MB
selective_prune: pruning 4714416/8989294 lowest-error ±1 values (excess=589302B)
Serialized model int6+brotli: 15148811 bytes
Total submission size int6+brotli: 15173470 bytes
final_int6_roundtrip val_loss:2.71751254 val_bpb:1.18099965 eval_time:8188ms
final_int6_sliding_window val_loss:2.67174312 val_bpb:1.16110878 eval_time:76629ms
@@ -0,0 +1,134 @@
W0405 12:16:26.757000 47726 torch/distributed/run.py:803]
W0405 12:16:26.757000 47726 torch/distributed/run.py:803] *****************************************
W0405 12:16:26.757000 47726 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0405 12:16:26.757000 47726 torch/distributed/run.py:803] *****************************************
Hyperparameters:
adam_eps: 1e-08
adam_wd: 0.02
beta1: 0.9
beta2: 0.95
cautious_muon: True
compressor: brotli
data_dir: ./data/
datasets_dir: ./data/datasets/fineweb10B_sp4096
distributed: True
ema_decay: 0.997
embed_lr: 0.6
embed_wd: 0.09
embedding_dim: 512
eval_seq_len: 2048
eval_stride: 64
gptq_calibration_batches: 64
gptq_enabled: True
gptq_reserve_seconds: 10.0
grad_accum_steps: 1
grad_clip_norm: 0.3
head_lr: 0.008
is_main_process: True
iterations: 20000
ln_scale: True
local_rank: 0
logfile: logs/6e35fe72-b1cf-49e5-95dc-2b7d967c8075.txt
logit_softcap: 30.0
matrix_lr: 0.02
max_wallclock_seconds: 600.0
min_lr: 0.0
mlp_mult: 4.0
model_dim: 512
model_path: final_model.pt
muon_backend_steps: 5
muon_beta2: 0.95
muon_momentum: 0.99
muon_momentum_warmup_start: 0.92
muon_momentum_warmup_steps: 1500
muon_wd: 0.09
num_heads: 8
num_kv_heads: 4
num_layers: 11
parallel_start_layer: 7
qk_gain_init: 5.0
quantized_model_path: final_model.int6.ptz
rank: 0
recur_layers: 4,5
recur_start_step: 3000
rope_base: 10000.0
rope_dims: 16
rope_train_seq_len: 2048
run_id: 6e35fe72-b1cf-49e5-95dc-2b7d967c8075
scalar_lr: 0.02
seed: 42
skip_gates_enabled: True
sliding_window_enabled: True
tie_embeddings: True
tied_embed_init_std: 0.005
tied_embed_lr: 0.03
tokenizer_path: ./data/tokenizers/fineweb_4096_bpe.model
train_batch_tokens: 786432
train_files: ./data/datasets/fineweb10B_sp4096/fineweb_train_*.bin
train_log_every: 500
train_seq_len: 2048
ttt_batch_seqs: 32
ttt_chunk_tokens: 32768
ttt_enabled: False
ttt_epochs: 3
ttt_freeze_blocks: 0
ttt_grad_clip: 1.0
ttt_lr: 0.002
ttt_momentum: 0.9
val_batch_tokens: 524288
val_files: ./data/datasets/fineweb10B_sp4096/fineweb_val_*.bin
val_loss_every: 4000
ve_dim: 128
ve_enabled: True
ve_layers: 9,10
vocab_size: 4096
warmdown_frac: 0.667
warmup_steps: 20
world_size: 8
xsa_last_n: 11
train_shards: 80
val_tokens: 45508608
model_params:34401372
gptq:reserving 10s, effective=590000ms
warmup_step: 1/20
warmup_step: 2/20
warmup_step: 3/20
warmup_step: 4/20
warmup_step: 5/20
warmup_step: 6/20
warmup_step: 10/20
warmup_step: 20/20
0/20000 val_loss: 8.3187 val_bpb: 3.6152
1/20000 train_loss: 8.3201 train_time: 0.0m tok/s: 8409702
2/20000 train_loss: 11.3749 train_time: 0.0m tok/s: 8301223
3/20000 train_loss: 8.9270 train_time: 0.0m tok/s: 8232435
4/20000 train_loss: 8.0370 train_time: 0.0m tok/s: 8197540
5/20000 train_loss: 8.6257 train_time: 0.0m tok/s: 8158832
500/20000 train_loss: 3.0617 train_time: 0.8m tok/s: 7952184
1000/20000 train_loss: 2.9819 train_time: 1.7m tok/s: 7929746
1500/20000 train_loss: 2.9701 train_time: 2.5m tok/s: 7920772
2000/20000 train_loss: 2.6938 train_time: 3.3m tok/s: 7916245
2500/20000 train_loss: 2.7396 train_time: 4.1m tok/s: 7918140
3000/20000 train_loss: 2.7789 train_time: 5.0m tok/s: 7919917
recurrence:activated at step 3000, virtual_layers=[0, 1, 2, 3, 4, 5, 4, 5, 6, 7, 8, 9, 10]
3500/20000 train_loss: 2.6938 train_time: 6.1m tok/s: 7516573
4000/20000 train_loss: 2.6257 train_time: 7.1m tok/s: 7425494
4000/20000 val_loss: 2.6472 val_bpb: 1.1505
4500/20000 train_loss: 2.5745 train_time: 8.0m tok/s: 7356557
5000/20000 train_loss: 2.5209 train_time: 9.0m tok/s: 7302902
5450/20000 val_loss: 2.5381 val_bpb: 1.1030
stopping_early: wallclock_cap train_time: 590058ms step: 5450/20000
peak memory allocated: 30120 MiB reserved: 30154 MiB
ema:applying EMA weights
pre-quantization post-ema val_loss:2.53669680 val_bpb:1.10241920 eval_time:2002ms
Serialized model: 132406149 bytes
Code size: 24659 bytes
GPTQ:collecting Hessians from calibration data...
GPTQ:collected 66 Hessians in 9.8s
GPTQ quantization: 66 layers with full GPTQ, 0 fallback to clip-search
selective_prune: unpruned=16.59MB target=16.0MB
selective_prune: pruning 4683904/8992183 lowest-error ±1 values (excess=585488B)
Serialized model int6+brotli: 15154845 bytes
Total submission size int6+brotli: 15179504 bytes
final_int6_roundtrip val_loss:2.70991443 val_bpb:1.17769759 eval_time:23408ms
final_int6_sliding_window val_loss:2.66190192 val_bpb:1.15683191 eval_time:99175ms