Record: 10L CountInitBigram + XSA + PartialRoPE (val_bpb=1.1522) #482
Closed — harsha-gouru wants to merge 1 commit into openai:main from harsha-gouru:submission/2026-03-22_CountInitBigram_XSA
records/track_10min_16mb/2026-03-22_10L_CountInitBigram_XSA_PartialRoPE/README.md (+89 lines)
# 10L CountInitBigram + XSA + PartialRoPE + LN Scale

**val_bpb: 1.1522** (sliding window stride=64, post int5/int6+zstd quantization roundtrip)
## Run Command

```bash
torchrun --standalone --nproc_per_node=8 train_gpt.py
```

All parameters are set as defaults in `train_gpt.py`. No env vars needed.
## Key Techniques
### 1. Count-Initialized Exact Bigram Logit Head (Novel)
A 1024x1024 lookup table providing exact bigram logit biases with zero hash collisions.
Initialized from corpus transition probabilities before training:

```
B[a,b] = log p(b|a) - log p(b)
```

Computed from the first 16M training tokens with additive smoothing (alpha=0.25), clipped to [-4, 4].
Applied BEFORE logit softcap so the bias is properly bounded.
The table is quantized to int4 with nibble packing (524KB vs 1MB at int8).

This gives the model a strong count-based language model prior from step 0, which the neural
network only needs to refine.
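The table construction can be sketched as follows (a minimal sketch: the function and variable names are illustrative, not the PR's actual identifiers):

```python
import numpy as np

def build_bigram_table(tokens, vocab_size=1024, alpha=0.25, clip=4.0):
    """B[a, b] = log p(b|a) - log p(b), with additive smoothing, clipped."""
    counts = np.zeros((vocab_size, vocab_size), dtype=np.float64)
    # Count each (a, b) transition in the token stream.
    np.add.at(counts, (tokens[:-1], tokens[1:]), 1.0)
    # Smoothed conditional p(b|a).
    cond = (counts + alpha) / (counts.sum(axis=1, keepdims=True) + alpha * vocab_size)
    # Smoothed unigram p(b).
    uni_counts = np.bincount(tokens, minlength=vocab_size).astype(np.float64)
    uni = (uni_counts + alpha) / (uni_counts.sum() + alpha * vocab_size)
    table = np.log(cond) - np.log(uni)[None, :]
    return np.clip(table, -clip, clip)

rng = np.random.default_rng(0)
toks = rng.integers(0, 1024, size=100_000)
B = build_bigram_table(toks)
```

The smoothing keeps every log-probability finite even for unseen bigrams, and the clip keeps the bias bounded before the softcap.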
### 2. Int4 Nibble Packing (Novel)
Custom `pack_i4` / `unpack_i4` functions pack signed int4 values [-8,7] into uint8 bytes
(two values per byte). Applied to the bigram logit table, halving its storage cost.
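A minimal sketch of the nibble packing (the PR's actual `pack_i4`/`unpack_i4` may pick a different nibble order):

```python
import numpy as np

def pack_i4(x):
    """Pack signed int4 values in [-8, 7] two-per-byte into uint8."""
    assert x.size % 2 == 0
    u = (x.astype(np.int16) + 8).astype(np.uint8)  # shift to unsigned [0, 15]
    return (u[0::2] << 4) | u[1::2]                # high nibble first

def unpack_i4(p):
    """Inverse of pack_i4: recover the signed int4 array."""
    out = np.empty(p.size * 2, dtype=np.int8)
    out[0::2] = (p >> 4).astype(np.int8) - 8
    out[1::2] = (p & 0x0F).astype(np.int8) - 8
    return out

vals = np.array([-8, 7, 0, -1], dtype=np.int8)
packed = pack_i4(vals)
```

The roundtrip is exact over the full [-8, 7] range, and the packed buffer is half the size of the int8 original.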
### 3. XSA (Exclusive Self Attention) - Last 4 Layers
Removes the self-value component from attention output (arxiv:2603.09078).
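One way to realize this is to subtract each token's own attention-weighted value after the softmax, i.e. `y_i = sum_j a_ij v_j - a_ii v_i`. A single-head sketch under that assumption (the cited paper may define XSA differently):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def xsa_attention(q, k, v):
    """Causal attention whose output drops the self-value term."""
    T, d = q.shape
    scores = q @ k.T / np.sqrt(d)
    scores = np.where(np.tril(np.ones((T, T), bool)), scores, -np.inf)
    a = softmax(scores, axis=-1)
    y = a @ v
    return y - np.diag(a)[:, None] * v  # remove each token's own value

rng = np.random.default_rng(0)
q, k, v = rng.standard_normal((3, 5, 8))
y = xsa_attention(q, k, v)
```

Note the first token's output becomes zero under this formulation, since it can only attend to itself.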
### 4. Partial RoPE (16 of 64 dims)
Apply rotary position embeddings to only 25% of head dimensions. The remaining 75%
attend without positional bias, acting as position-independent feature detectors.
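A sketch of rotating only the first 16 of 64 dims, assuming a paired-halves layout (the PR may interleave pairs or choose a different frequency schedule):

```python
import numpy as np

def partial_rope(x, rot_dims=16, base=10000.0):
    """Apply RoPE to the first rot_dims of each head vector; pass the
    remaining dims through untouched."""
    T, d = x.shape
    half = rot_dims // 2
    inv_freq = base ** (-np.arange(half) / half)
    ang = np.arange(T)[:, None] * inv_freq[None, :]  # (T, half)
    x1, x2 = x[:, :half], x[:, half:rot_dims]
    rot = np.concatenate(
        [x1 * np.cos(ang) - x2 * np.sin(ang),
         x1 * np.sin(ang) + x2 * np.cos(ang)], axis=-1)
    return np.concatenate([rot, x[:, rot_dims:]], axis=-1)

rng = np.random.default_rng(0)
x = rng.standard_normal((10, 64))
y = partial_rope(x)
```

Only the rotated 16 dims carry positional phase; the other 48 are identical before and after, which is the "position-independent feature detector" half of the design.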
### 5. LN Scale
Block outputs scaled by `1/sqrt(layer_idx + 1)`. Damps deeper layers' contributions,
stabilizing training. Zero parameters.
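A toy sketch of this damping in a residual stack (hypothetical names; the real blocks are transformer layers):

```python
import math

def apply_blocks(x, blocks):
    """Residual stack with LN Scale: block i's output is damped by
    1/sqrt(i + 1) before being added back into the stream."""
    for i, block in enumerate(blocks):
        x = x + block(x) / math.sqrt(i + 1)
    return x

# Toy check with scalar "blocks" standing in for transformer layers.
out = apply_blocks(1.0, [lambda v: v, lambda v: v])
```

Because the scale is a fixed function of depth, it adds zero parameters and zero state to the checkpoint.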
### 6. Higher Learning Rates
- matrix_lr: 0.025 (up from 0.02)
- scalar_lr: 0.025
- tied_embed_lr: 0.035
## Architecture
- 10 layers, 512 dim, 8 heads, 4 KV heads (GQA)
- MLP 3x expansion (hidden=1536), relu squared activation
- SmearGate + exact BigramLogitHead (count-initialized, int4 packed)
- Orthogonal init with muP-scaled output projections
- U-Net skip connections, tied embeddings
## Training Hyperparameters
- Muon optimizer: matrix_lr=0.025, WD=0.04, momentum=0.99
- AdamW for embeddings/scalars: WD=0.04
- warmdown=2800 iters, warmup=20 steps
- seq_len=2048, batch=786K tokens
- grad_clip=0.3, 3% magnitude pruning
- SWA: start_frac=0.4, every=50 steps (22 checkpoints)
- Sliding window eval: stride=64
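The stride-64 sliding-window evaluation can be sketched as follows (an assumption-laden sketch: each window is re-scored with near-full context but only the newest `stride` positions are counted, so every token is scored exactly once; the PR's exact bookkeeping may differ):

```python
import numpy as np

def sliding_windows(tokens, seq_len=2048, stride=64):
    """Yield (window, n_eval) pairs for stride-based sliding-window eval:
    the first window scores all positions, later windows only their new
    final `stride` positions."""
    windows = []
    for start in range(0, len(tokens) - seq_len + 1, stride):
        win = tokens[start:start + seq_len]
        n_eval = seq_len if start == 0 else stride
        windows.append((win, n_eval))
    return windows

toks = np.arange(4096)
wins = sliding_windows(toks)
```

The small stride is why the final eval is slow relative to training-time eval: each token's loss is computed with close to the full 2048-token context.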
## Quantization
- Int5 [-16,15] for MLP weights
- Int6 [-32,31] for attention weights
- Int4 [-8,7] nibble-packed for bigram logit table
- FP16 for tied embeddings
- zstd-22 compression
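A sketch of the symmetric integer quantization roundtrip (per-tensor scale assumed for illustration; the PR may use per-channel scales and other details):

```python
import numpy as np

def quantize_sym(w, lo, hi):
    """Symmetric quantization of a float tensor to the integer range
    [lo, hi], e.g. int5 = [-16, 15] or int6 = [-32, 31]."""
    scale = np.abs(w).max() / max(-lo, hi)
    q = np.clip(np.round(w / scale), lo, hi).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    """Map quantized integers back to floats."""
    return q.astype(np.float32) * scale

rng = np.random.default_rng(0)
w = rng.standard_normal(1024).astype(np.float32)
q, s = quantize_sym(w, -16, 15)  # the int5 range used for MLP weights
w_hat = dequantize(q, s)
```

The "quantization roundtrip" in the headline number means the reported val_bpb is measured on `w_hat`, the weights after quantize-then-dequantize, not on the full-precision weights.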
## Results
```
Steps completed: 6267 (wallclock capped at 600s)
Step time: 95.75 ms/step
Peak memory: 19609 MiB allocated, 19878 MiB reserved

Pre-SWA val_bpb: 1.1563
Post-SWA+quant val_bpb: 1.1522
Quant gap: 0.004 bpb

Artifact size: 15,322,709 bytes (int6+zstd)
Code size: 61,523 bytes
Total: 15,384,232 bytes (under 16,000,000 limit)
```
Built on the baseline by @thwu1 (PR #180). Adopts XSA from arxiv:2603.09078, and Partial RoPE and LN Scale from PR #315.
records/track_10min_16mb/2026-03-22_10L_CountInitBigram_XSA_PartialRoPE/submission.json (+12 lines)
```json
{
  "name": "10L CountInitBigram + XSA + PartialRoPE + LN Scale",
  "val_loss": 1.94544622,
  "val_bpb": 1.15220588,
  "bytes_total": 15384232,
  "bytes_model_int6_zstd": 15322709,
  "bytes_code": 61523,
  "blurb": "10 layers with count-initialized exact bigram logit head (corpus log-prob residuals, int4 nibble-packed), XSA on last 4 layers, Partial RoPE (16/64 dims), LN Scale (1/sqrt(layer+1)), higher LR (0.025). SWA over 22 checkpoints, sliding window eval stride=64.",
  "author": "Sri Harsha Gouru",
  "github_id": "harsha-gouru",
  "date": "2026-03-22"
}
```
records/track_10min_16mb/2026-03-22_10L_CountInitBigram_XSA_PartialRoPE/train.log (+164 lines)
```
├── 🔨 Created mount train_gpt_submission.py
└── 🔨 Created function train_8gpu.
Downloading: variant=sp1024, train_shards=80...
seed:42
seed:42
seed:42
seed:42
warmup_step:1/20
warmup_step:1/20
warmup_step:1/20
warmup_step:20/20
warmup_step:20/20
warmup_step:20/20
step:200/20000 train_loss:2.3712 train_time:19185ms step_avg:95.93ms
step:500/20000 val_loss:2.3625 val_bpb:1.3992 train_time:47906ms step_avg:95.81ms
step:800/20000 train_loss:2.2471 train_time:76829ms step_avg:96.04ms
step:1100/20000 train_loss:2.3264 train_time:105475ms step_avg:95.89ms
step:1400/20000 train_loss:2.1932 train_time:134238ms step_avg:95.88ms
step:1700/20000 train_loss:2.1487 train_time:162862ms step_avg:95.80ms
step:2000/20000 val_loss:2.1413 val_bpb:1.2682 train_time:191524ms step_avg:95.76ms
step:2300/20000 train_loss:2.1218 train_time:220300ms step_avg:95.78ms
step:2600/20000 train_loss:2.1206 train_time:248938ms step_avg:95.75ms
step:2900/20000 train_loss:2.0369 train_time:277499ms step_avg:95.69ms
step:3200/20000 train_loss:2.1853 train_time:306283ms step_avg:95.71ms
step:3500/20000 val_loss:2.0956 val_bpb:1.2411 train_time:334990ms step_avg:95.71ms
step:3800/20000 train_loss:2.0836 train_time:363568ms step_avg:95.68ms
step:4100/20000 train_loss:2.0146 train_time:392308ms step_avg:95.68ms
step:4400/20000 train_loss:2.0390 train_time:420921ms step_avg:95.66ms
step:4700/20000 train_loss:2.2341 train_time:449553ms step_avg:95.65ms
step:5000/20000 val_loss:2.0222 val_bpb:1.1976 train_time:478355ms step_avg:95.67ms
step:5300/20000 train_loss:2.0038 train_time:507084ms step_avg:95.68ms
step:5600/20000 train_loss:1.9457 train_time:535996ms step_avg:95.71ms
step:5900/20000 train_loss:1.8864 train_time:564743ms step_avg:95.72ms
step:6200/20000 train_loss:1.9320 train_time:593509ms step_avg:95.73ms
Final summary: {'run_id': 'submission-10L-v1', 'gpus': 8, 'exit_code': 0, 'elapsed_s': 1140.4, 'artifact_bytes': 15322709, 'code_bytes': 61523, 'total_bytes': 15384232, 'under_limit': True, 'headroom_bytes': 615768, 'final_line': 'final_int8_zlib_roundtrip_exact val_loss:1.94544622 val_bpb:1.15220588', 'val_bpb': 1.15220588}
val_bpb = 1.152206
final_int8_zlib_roundtrip_exact val_loss:1.94544622 val_bpb:1.15220588
val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/data/data/tokenizers/fineweb_1024_bpe.model
train_loader:dataset:fineweb10B_sp1024 train_shards:80
val_loader:shards pattern=/data/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
bigram_init:count_based time:0.22s
model_params:25779282
world_size:8 grad_accum_steps:1
attention_mode:gqa num_heads:8 num_kv_heads:4
tie_embeddings:True embed_lr:0.035 matrix_lr:0.025 scalar_lr:0.025
train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000
seed:42
warmup_step:1/20
warmup_step:2/20
warmup_step:3/20
warmup_step:4/20
warmup_step:5/20
warmup_step:6/20
warmup_step:7/20
warmup_step:8/20
warmup_step:9/20
warmup_step:10/20
warmup_step:11/20
warmup_step:12/20
warmup_step:13/20
warmup_step:14/20
warmup_step:15/20
warmup_step:16/20
warmup_step:17/20
warmup_step:18/20
warmup_step:19/20
warmup_step:20/20
step:0/20000 val_loss:7.3822 val_bpb:4.3721 train_time:0ms step_avg:0.01ms
step:1/20000 train_loss:7.3777 train_time:223ms step_avg:223.20ms
step:2/20000 train_loss:5.1389 train_time:303ms step_avg:151.25ms
step:3/20000 train_loss:6.8711 train_time:399ms step_avg:132.99ms
step:4/20000 train_loss:7.9386 train_time:498ms step_avg:124.43ms
step:5/20000 train_loss:7.0619 train_time:593ms step_avg:118.64ms
step:6/20000 train_loss:6.4111 train_time:689ms step_avg:114.78ms
step:7/20000 train_loss:5.8449 train_time:785ms step_avg:112.11ms
step:8/20000 train_loss:5.4955 train_time:879ms step_avg:109.93ms
step:9/20000 train_loss:5.4268 train_time:974ms step_avg:108.26ms
step:10/20000 train_loss:5.1444 train_time:1069ms step_avg:106.90ms
step:100/20000 train_loss:3.1036 train_time:9588ms step_avg:95.88ms
step:200/20000 train_loss:2.3712 train_time:19185ms step_avg:95.93ms
step:300/20000 train_loss:2.5539 train_time:28785ms step_avg:95.95ms
step:400/20000 train_loss:2.4180 train_time:38370ms step_avg:95.92ms
step:500/20000 train_loss:2.4018 train_time:47882ms step_avg:95.76ms
step:500/20000 val_loss:2.3625 val_bpb:1.3992 train_time:47906ms step_avg:95.81ms
step:600/20000 train_loss:2.3373 train_time:57505ms step_avg:95.84ms
step:700/20000 train_loss:2.3546 train_time:67225ms step_avg:96.04ms
step:800/20000 train_loss:2.2471 train_time:76829ms step_avg:96.04ms
step:900/20000 train_loss:2.1361 train_time:86414ms step_avg:96.02ms
step:1000/20000 train_loss:2.2881 train_time:95886ms step_avg:95.89ms
step:1000/20000 val_loss:2.2347 val_bpb:1.3235 train_time:95910ms step_avg:95.91ms
step:1100/20000 train_loss:2.3264 train_time:105475ms step_avg:95.89ms
step:1200/20000 train_loss:2.3602 train_time:115065ms step_avg:95.89ms
step:1300/20000 train_loss:2.1075 train_time:124651ms step_avg:95.89ms
step:1400/20000 train_loss:2.1932 train_time:134238ms step_avg:95.88ms
step:1500/20000 train_loss:2.2303 train_time:143696ms step_avg:95.80ms
step:1500/20000 val_loss:2.1931 val_bpb:1.2989 train_time:143719ms step_avg:95.81ms
step:1600/20000 train_loss:2.0843 train_time:153282ms step_avg:95.80ms
step:1700/20000 train_loss:2.1487 train_time:162862ms step_avg:95.80ms
step:1800/20000 train_loss:2.1646 train_time:172451ms step_avg:95.81ms
step:1900/20000 train_loss:2.1335 train_time:181904ms step_avg:95.74ms
step:2000/20000 train_loss:2.0755 train_time:191501ms step_avg:95.75ms
step:2000/20000 val_loss:2.1413 val_bpb:1.2682 train_time:191524ms step_avg:95.76ms
step:2100/20000 train_loss:2.0542 train_time:201099ms step_avg:95.76ms
step:2200/20000 train_loss:2.1418 train_time:210687ms step_avg:95.77ms
step:2300/20000 train_loss:2.1218 train_time:220300ms step_avg:95.78ms
step:2400/20000 train_loss:2.0761 train_time:229761ms step_avg:95.73ms
step:2500/20000 train_loss:2.1798 train_time:239357ms step_avg:95.74ms
step:2500/20000 val_loss:2.1160 val_bpb:1.2532 train_time:239380ms step_avg:95.75ms
step:2600/20000 train_loss:2.1206 train_time:248938ms step_avg:95.75ms
step:2700/20000 train_loss:2.1137 train_time:258488ms step_avg:95.74ms
step:2800/20000 train_loss:2.1632 train_time:268049ms step_avg:95.73ms
step:2900/20000 train_loss:2.0369 train_time:277499ms step_avg:95.69ms
step:3000/20000 train_loss:2.1721 train_time:287094ms step_avg:95.70ms
step:3000/20000 val_loss:2.1034 val_bpb:1.2457 train_time:287118ms step_avg:95.71ms
step:3100/20000 train_loss:2.0467 train_time:296656ms step_avg:95.70ms
step:3200/20000 train_loss:2.1853 train_time:306283ms step_avg:95.71ms
step:3300/20000 train_loss:2.0843 train_time:315738ms step_avg:95.68ms
step:3400/20000 train_loss:2.0352 train_time:325337ms step_avg:95.69ms
step:3500/20000 train_loss:2.1954 train_time:334966ms step_avg:95.70ms
step:3500/20000 val_loss:2.0956 val_bpb:1.2411 train_time:334990ms step_avg:95.71ms
step:3600/20000 train_loss:2.1086 train_time:344546ms step_avg:95.71ms
step:3700/20000 train_loss:2.1098 train_time:354112ms step_avg:95.71ms
step:3800/20000 train_loss:2.0836 train_time:363568ms step_avg:95.68ms
step:3900/20000 train_loss:2.0833 train_time:373128ms step_avg:95.67ms
step:4000/20000 train_loss:1.9800 train_time:382712ms step_avg:95.68ms
step:4000/20000 val_loss:2.0729 val_bpb:1.2277 train_time:382735ms step_avg:95.68ms
step:4100/20000 train_loss:2.0146 train_time:392308ms step_avg:95.68ms
step:4200/20000 train_loss:2.1597 train_time:401902ms step_avg:95.69ms
step:4300/20000 train_loss:2.0659 train_time:411352ms step_avg:95.66ms
step:4400/20000 train_loss:2.0390 train_time:420921ms step_avg:95.66ms
step:4500/20000 train_loss:2.1277 train_time:430528ms step_avg:95.67ms
step:4500/20000 val_loss:2.0468 val_bpb:1.2122 train_time:430552ms step_avg:95.68ms
step:4600/20000 train_loss:1.8398 train_time:440101ms step_avg:95.67ms
step:4700/20000 train_loss:2.2341 train_time:449553ms step_avg:95.65ms
step:4800/20000 train_loss:2.4262 train_time:459158ms step_avg:95.66ms
step:4900/20000 train_loss:2.0473 train_time:468724ms step_avg:95.66ms
step:5000/20000 train_loss:2.1047 train_time:478331ms step_avg:95.67ms
step:5000/20000 val_loss:2.0222 val_bpb:1.1976 train_time:478355ms step_avg:95.67ms
step:5100/20000 train_loss:2.1237 train_time:487891ms step_avg:95.66ms
swa:start step:5200
step:5200/20000 train_loss:2.0397 train_time:497349ms step_avg:95.64ms
step:5300/20000 train_loss:2.0038 train_time:507084ms step_avg:95.68ms
step:5400/20000 train_loss:2.0424 train_time:516727ms step_avg:95.69ms
step:5500/20000 train_loss:2.0107 train_time:526366ms step_avg:95.70ms
step:5500/20000 val_loss:1.9943 val_bpb:1.1811 train_time:526416ms step_avg:95.71ms
step:5600/20000 train_loss:1.9457 train_time:535996ms step_avg:95.71ms
step:5700/20000 train_loss:2.0044 train_time:545480ms step_avg:95.70ms
step:5800/20000 train_loss:1.9851 train_time:555138ms step_avg:95.71ms
step:5900/20000 train_loss:1.8864 train_time:564743ms step_avg:95.72ms
step:6000/20000 train_loss:1.9267 train_time:574385ms step_avg:95.73ms
step:6000/20000 val_loss:1.9648 val_bpb:1.1636 train_time:574435ms step_avg:95.74ms
step:6100/20000 train_loss:1.9044 train_time:583877ms step_avg:95.72ms
step:6200/20000 train_loss:1.9320 train_time:593509ms step_avg:95.73ms
step:6267/20000 val_loss:1.9523 val_bpb:1.1563 train_time:600046ms step_avg:95.75ms
stopping_early: wallclock_cap train_time:600046ms step:6267/20000
peak memory allocated: 19609 MiB reserved: 19878 MiB
swa:applying averaged 22 checkpoints
Serialized model: 98962351 bytes
Serialized model int6+zstd: 15322709 bytes
final_eval_mode:sliding_window stride:64 batch_seqs:32
final_int8_zlib_roundtrip val_loss:1.9454 val_bpb:1.1522 eval_time:183742ms
final_int8_zlib_roundtrip_exact val_loss:1.94544622 val_bpb:1.15220588
"final_line": "final_int8_zlib_roundtrip_exact val_loss:1.94544622 val_bpb:1.15220588",
"val_bpb": 1.15220588
```
Review comment: The README says "All parameters are set as defaults… No env vars needed", but `train_gpt.py` has `BIGRAM_LOGIT_HEAD` defaulting to 0. As a result, running the documented command won't enable the count-initialized exact bigram logit head unless an env var is provided (or the script default is changed). Please update the README command/notes to reflect the needed env var(s), or change the script defaults to match the documented configuration.