[wip] depth-recurrent QAT: 3x3 shared blocks + LoRA #38

Closed
kxddry wants to merge 32 commits into openai:main from kxddry:main

Conversation


@kxddry kxddry commented Mar 19, 2026

Summary

New submission to records/track_10min_16mb/ using a depth-recurrent transformer with
quantization-aware training.

Key optimizations

  • Depth recurrence: 3 shared transformer blocks looped 3× with per-iteration LoRA deltas
    (rank 4), replacing 9 unique blocks. Frees ~66% of the parameter budget while
    maintaining effective depth (see the budget sketch after this list).
  • QAT: STE fake-quantize in the CastedLinear forward pass. The model trains with INT8
    noise, reducing post-quantization BPB degradation by 18× (0.00012 vs 0.00217 baseline).
  • Byte grouping: Separates int8 weights from fp16 scales before zlib for ~2× better
    compression ratio.
  • Model widening: dim 512→768 using freed parameter budget (13.2M params, ~4.1MB
    artifact).
  • LAWA: Checkpoint averaging during warmdown for a free quality boost.
  • RoPE NTK scaling: 2048-token eval context (trained at 1024).
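
The ~66% figure in the depth-recurrence bullet is simple counting; a rough sketch with illustrative sizes (not the exact accounting done by log_param_budget() in the diff):

```python
# Keeping 3 shared blocks instead of 9 unique ones retains 1/3 of the
# block parameters, freeing the other ~66%:
blocks_before, blocks_after = 9, 3
freed = 1 - blocks_after / blocks_before       # ~0.667

# Each of the 9 effective layers adds back only a rank-4 LoRA pair per
# adapted matrix, which is tiny at dim 768:
dim, rank = 768, 4
lora_params_per_matrix = 2 * dim * rank        # 6,144 vs dim * dim = 589,824
```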

Local validation (MLX, 200 iterations, 1 shard)

Config                      val_bpb   Artifact size   INT8 degradation
Baseline (512d, 9 blocks)   3.022     5.35 MB         0.00217
Ours (768d, 3×3 + LoRA)     2.962     4.12 MB         0.00012

H100 multi-seed results pending compute grant approval.

Submission checklist

  • records/track_10min_16mb/2026-03-19_DepthRecurrentQAT/README.md
  • records/track_10min_16mb/2026-03-19_DepthRecurrentQAT/submission.json
  • records/track_10min_16mb/2026-03-19_DepthRecurrentQAT/train_gpt.py
  • Training log (pending H100 runs)
  • Multi-seed validation p < 0.01 (pending H100 runs)

0hq and others added 30 commits March 18, 2026 09:33
- Add fake_quantize_int8() with per-row INT8 STE pattern
- Modify CastedLinear.forward() to apply fake-quantize during training only
- Uses full-range amax (no clip percentile) so model learns to avoid outliers
- STE gradient passthrough via (x_q * scale - x).detach() + x
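
A minimal PyTorch sketch of the pattern this commit describes (per-row symmetric INT8, full-range amax, STE passthrough); the function name matches the commit, the exact signature is assumed:

```python
import torch

def fake_quantize_int8(w: torch.Tensor) -> torch.Tensor:
    # Per-row symmetric INT8: full-range amax, no clip percentile, so the
    # model learns to avoid outliers rather than having them clipped away.
    scale = w.abs().amax(dim=-1, keepdim=True).clamp_min(1e-8) / 127.0
    w_q = torch.round(w / scale).clamp(-127, 127)
    # Straight-through estimator: forward sees the quantized weight,
    # backward passes gradients through to w unchanged.
    return (w_q * scale - w).detach() + w
```

Per the commit, CastedLinear.forward applies this only during training, so the quantization noise is seen by the optimizer but inference uses the real quantized weights.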
- Add byte_group_quantized() to separate int8 weights from fp16 scales before zlib
- Add byte_ungroup_quantized() for deserialization roundtrip
- Replace torch.save serialization with byte grouping in main()
- Replace torch.load deserialization with byte ungrouping in main()
- Add import struct for binary packing of metadata lengths
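
The grouping idea in sketch form (the real byte_group_quantized also has to record per-tensor names and shapes so byte_ungroup_quantized can rebuild the state dict; only the compression trick is shown here):

```python
import struct, zlib
import numpy as np

def byte_group_quantized(int8_weights, fp16_scales):
    # Group homogeneous byte streams so zlib sees long runs of similar
    # bytes: all int8 weight bytes first, then all fp16 scale bytes.
    w = np.concatenate([a.ravel() for a in int8_weights]).astype(np.int8).tobytes()
    s = np.concatenate([a.ravel() for a in fp16_scales]).astype(np.float16).tobytes()
    header = struct.pack("<QQ", len(w), len(s))  # lengths for the ungroup roundtrip
    return zlib.compress(header + w + s, level=9)
```

Interleaved int8/fp16 bytes (as torch.save produces) defeat zlib's run detection; separating the streams is where the ~2× ratio gain in the PR summary comes from.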
- Assert total_submission_bytes < 16_000_000 after serialization
- Descriptive error message shows code bytes, model bytes, and headroom
- Success log confirms artifact size check passed with headroom
- Runs only on master_process (rank 0)
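
A sketch of the described check, with the component names taken from the commit text:

```python
def assert_artifact_size(code_bytes: int, model_bytes: int, limit: int = 16_000_000):
    # Fail loudly with component sizes; log headroom on success.
    total = code_bytes + model_bytes
    assert total < limit, (
        f"submission over budget: code={code_bytes}B, model={model_bytes}B, "
        f"over by {total - limit}B")
    print(f"artifact size check passed, headroom {limit - total}B")
```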
…ging

- Capture float_val_bpb before quantization for comparison
- Compute bpb_degradation after INT8 roundtrip eval
- Log bpb_comparison with 8-decimal precision
- Soft warning if degradation exceeds 0.01 threshold
- Add num_shared_blocks (default 3), num_loops (default 3), lora_rank (default 4)
- Add LoRADelta class with CastedLinear for QAT compatibility
- Zero-init on up projection ensures delta starts as identity
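
A sketch of the class these bullets describe; the real LoRADelta uses CastedLinear so the adapter weights also pass through QAT fake-quantize, plain nn.Linear is used here for self-containment:

```python
import torch.nn as nn

class LoRADelta(nn.Module):
    # Rank-r additive correction applied per effective layer.
    def __init__(self, dim: int, rank: int = 4):
        super().__init__()
        self.down = nn.Linear(dim, rank, bias=False)
        self.up = nn.Linear(rank, dim, bias=False)
        nn.init.zeros_(self.up.weight)  # zero-init up => delta starts as identity

    def forward(self, x):
        return self.up(self.down(x))
```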
…LoRA deltas

- Replace 9 unique blocks with 3 shared blocks looped 3 times
- Add per-iteration LoRA deltas indexed by effective layer
- Remove U-Net skip connections (skip_weights, encoder/decoder split)
- Update optimizer setup to use shared_blocks and lora_deltas params
- Remove skip_weight from CONTROL_TENSOR_NAME_PATTERNS
- Update GPT instantiation to pass new hyperparameters
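
The core loop this commit introduces, sketched under the assumption that each delta is applied additively alongside its shared block (real blocks also take RoPE and masking arguments omitted here):

```python
def depth_recurrent_forward(x, shared_blocks, lora_deltas, num_loops=3):
    # 3 shared blocks x 3 loops = 9 effective layers; each effective
    # layer gets its own LoRA delta so iterations are not identical.
    for loop in range(num_loops):
        for i, block in enumerate(shared_blocks):
            eff = loop * len(shared_blocks) + i
            x = block(x) + lora_deltas[eff](x)
    return x
```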
- Route shared block params to Muon with LR scaled by 1/sqrt(num_loops)
- Route LoRA delta params to separate Adam optimizer (no gradient scaling)
- Add optimizer_lora to optimizers list for automatic LR warmdown
- Add depth recurrence diagnostics logging (shared_blocks, loops, effective_depth)
- Remove LoRA params from Muon (each LoRA delta used once per forward pass)
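
Schematically, with Muon and the parameter lists coming from the training script and the LR values purely illustrative:

```python
import math
import torch

# Shared block weights accumulate gradient from num_loops passes per
# step, hence the damped Muon LR; each LoRA delta fires once per
# forward pass and goes to a plain Adam group instead.
num_loops = 3
optimizer_muon = Muon(shared_block_params, lr=0.02 / math.sqrt(num_loops))
optimizer_lora = torch.optim.Adam(lora_delta_params, lr=3e-4)
optimizers = [optimizer_muon, optimizer_lora]  # both get the LR warmdown schedule
```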
- Assert num_loops is plain int (Dynamo unrolls constant loop counts)
- Verify fullgraph=True compatibility with recursive architecture
- File at 1279 lines (under 1500 line cap)
- Change MODEL_DIM default from 512 to 768 (1.5x wider, ~2.25x parameters per weight matrix)
- Add log_param_budget() for component-level byte accounting
- Update header comment to reflect new width
- All divisibility constraints verified (head_dim=96, GQA ratio=2)
…pport

- Add NuMuon hyperparameters (NUMUON_ENABLED, NUMUON_R_MIN) to Hyperparameters class
- Extend Muon with _numuon_update using top-k SVD with cosine rank scheduling
- Branch Muon.step() to use NuMuon when enabled (default off)
- Document vocabulary expansion to 2048 with env var instructions (sp2048)
- Set optimizer_muon._total_steps before training loop
- Add NuMuon status to training log output
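
The diff holds the real _numuon_update; purely as orientation, a top-k SVD step with cosine rank annealing (as the bullets describe) could look like this, with every detail below an assumption:

```python
import math
import torch

def numuon_update(grad: torch.Tensor, step: int, total_steps: int, r_min: int):
    # Orthogonalize the update via truncated SVD, cosine-annealing the
    # retained rank from full rank down toward r_min over training.
    r_max = min(grad.shape)
    frac = 0.5 * (1 + math.cos(math.pi * step / total_steps))  # 1 -> 0
    r = max(r_min, int(r_min + (r_max - r_min) * frac))
    U, S, Vh = torch.linalg.svd(grad.float(), full_matrices=False)
    return (U[:, :r] @ Vh[:r, :]).to(grad.dtype)  # rank-r orthogonalized update
```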
- Add eval_seq_len hyperparameter (default 2048) to Hyperparameters
- Add Rotary.scale_for_eval() for NTK-aware base frequency scaling
- Extend eval_val() with optional eval_seq_len parameter
- Apply RoPE scaling after training loop for extended-context eval
- Post-quantization eval uses extended context length
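
Rotary.scale_for_eval() itself isn't shown in this thread; the standard NTK-aware base scaling it presumably implements, with the default base an assumption and head_dim=96 from the widening commit:

```python
def ntk_scaled_base(base: float = 10000.0, train_len: int = 1024,
                    eval_len: int = 2048, head_dim: int = 96) -> float:
    # Inflate the RoPE base so the longest wavelength covers eval_len
    # while high-frequency components stay near their trained values.
    s = eval_len / train_len  # 2.0 here
    return base * s ** (head_dim / (head_dim - 2))
```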
- Add lawa_enabled and lawa_interval hyperparameters (env-var togglable)
- Accumulate weight snapshots every lawa_interval steps during warmdown
- Average snapshots after training loop ends, before RoPE scaling
- Re-evaluate with averaged weights and update reference BPB
- Graceful skip with fewer than 2 snapshots
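
The averaging step in sketch form, assuming snapshots are plain state dicts of tensors:

```python
def lawa_average(snapshots):
    # Average state-dict snapshots captured every lawa_interval steps
    # during warmdown; skip gracefully with fewer than 2 (per the commit).
    if len(snapshots) < 2:
        return None
    return {k: sum(s[k].float() for s in snapshots) / len(snapshots)
            for k in snapshots[0]}
```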
- README.md documenting all 6 optimizations (QAT, depth recurrence, LoRA, widening, LAWA, RoPE)
- submission.json matching competition schema with placeholder metrics
- Verbatim copy of train_gpt.py for self-contained submission
- Runs train_gpt.py with 5 seeds (1337, 42, 7, 2024, 31415) sequentially
- Parses final_int8_zlib_roundtrip_exact val_bpb from each run
- Computes mean, std dev, std error, and improvement over baseline 1.2244
- Evaluates statistical significance (one-sided t-test, p < 0.01, df=4)
- Saves results to multi_seed_results.txt
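
The statistics the script computes, sketched (helper name is hypothetical):

```python
import math
from statistics import mean, stdev

def seed_stats(bpbs, baseline=1.2244):
    # One-sided one-sample t-test that mean val_bpb beats the baseline;
    # with 5 seeds, df = 4 and the p < 0.01 critical value is t ~= 3.747.
    n = len(bpbs)
    m, sd = mean(bpbs), stdev(bpbs)
    se = sd / math.sqrt(n)
    t = (baseline - m) / se
    return m, sd, se, t
```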
Port the full optimization stack from train_gpt.py to train_gpt_mlx.py:

1. Depth recurrence architecture: 3 shared blocks x 3 loops with
   per-iteration LoRA deltas, replacing U-Net skip connections
2. QAT fake-quantize: INT8 simulation with STE in CastedLinear forward
3. Byte grouping compression: reorganize quantized state dict for
   better zlib compression ratios
4. LAWA checkpoint averaging: accumulate weight snapshots during
   warmdown and average them after training
5. NTK-aware RoPE scaling for extended-context evaluation (2048 tokens)
6. Parameter budget logging and artifact size assertion
7. BPB degradation logging (float vs quantized roundtrip)

Hyperparameters updated to match PyTorch defaults (model_dim=768,
num_shared_blocks=3, num_loops=3, lora_rank=4, eval_seq_len=2048).
MLX Timing Mismatch with Main Script
Fix MLX multi-batch validation memory growth
Resolve conflicts in train_gpt_mlx.py:
- Keep both eval_seq_len (ours) and log_fn (upstream) params
- Use our seq_len variable with upstream's total_batches/total_loss_sum
- Take upstream's timing fix (exclude eval from train_time_ms)
- Pass both eval_seq_len and log_fn at all eval_val call sites

@kxddry kxddry changed the title Add DepthRecurrentQAT submission (3x3 shared blocks + QAT + LoRA) Depth-recurrent QAT: 3x3 shared blocks + LoRA, dim 768, val_bpb 2.962 Mar 19, 2026
@kxddry kxddry changed the title Depth-recurrent QAT: 3x3 shared blocks + LoRA, dim 768, val_bpb 2.962 [wip] depth-recurrent QAT: 3x3 shared blocks + LoRA Mar 19, 2026
Restore original README.md from launch snapshot, reverting
cosmetic edits merged from openai/main.
@0hq 0hq closed this Mar 19, 2026
seconds-0 pushed a commit to seconds-0/parameter-golf that referenced this pull request Mar 20, 2026
E14 (QAT-lite) → 3 alternatives, cheapest first:
  E14b: quant noise injection (2% overhead, run first)
  E14a: STE fake-quantize (10% overhead)
  E14c: learned per-row scales / LSQ-lite (8% overhead)

E31 (MTP) → 3 sequential, each gates the next:
  E31a: 1 aux head (t+2), gate experiment for 15M scale
  E31b: 2 heads with forward curriculum
  E31c: decaying loss weights on E31b architecture

E34 (Muon upgrades) → 3 variants with public code:
  E34a: Polar Express drop-in (proven in speedrun openai#38)
  E34b: Turbo-Muon AOL preconditioning
  E34c: NorMuon neuron-wise adaptive LR (independent axis)

All sub-experiments now have specific implementation, line count
estimate, overhead estimate, code sources, and numeric kill rules.
Ideas section kept as empty placeholder.
gHashTag added a commit to gHashTag/parameter-golf that referenced this pull request Apr 30, 2026
…penai#38)

Both fleet-snapshot.yml and deploy-from-template.yml had a
'git commit -m "..."' string spanning blank lines inside a 'run: |'
block. GitHub Actions YAML scanner reads blank lines as end-of-mapping,
which caused all snapshot/deploy runs to fail with 'startup_failure'
(0 jobs, no logs):

    while scanning a simple key
    could not find expected ':'

Fix: collapse the commit body to a single line with '--' separators.
The information (totals, [skip ci]) is preserved.

Verified: yaml.safe_load passes for both files. Refs trios#143 / openai#16.

Co-authored-by: Perplexity Computer <computer@perplexity.ai>
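
The verification the commit cites is just a PyYAML load over both workflow files; the broken versions raise the scanner error quoted above:

```python
import yaml  # PyYAML

for path in (".github/workflows/fleet-snapshot.yml",
             ".github/workflows/deploy-from-template.yml"):
    with open(path) as f:
        yaml.safe_load(f)  # raises "could not find expected ':'" on the bad files
```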