Record: 10L CountInitBigram + XSA + PartialRoPE (val_bpb=1.1522) #485
# 10L CountInitBigram + XSA + PartialRoPE + LN Scale

**val_bpb: 1.1522** (sliding window stride=64, post int5/int6+zstd quantization roundtrip)
## Run Command

```bash
torchrun --standalone --nproc_per_node=8 train_gpt.py
```

All parameters are set as defaults in `train_gpt.py`. No env vars needed.
## Key Techniques

### 1. Count-Initialized Exact Bigram Logit Head (Novel)
A 1024x1024 lookup table providing exact bigram logit biases with zero hash collisions, initialized from corpus transition probabilities before training:

```
B[a,b] = log p(b|a) - log p(b)
```

Computed from the first 16M training tokens with additive smoothing (alpha=0.25) and clipped to [-4, 4]. The bias is applied BEFORE the logit softcap so it stays properly bounded. The table is quantized to int4 with nibble packing (524KB vs 1MB at int8).

This gives the model a strong count-based language-model prior from step 0, which the neural network only needs to refine.
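The count initialization described above can be sketched as follows. This is an illustrative reconstruction, not the repo's actual code: the function name `count_init_bigram` and the exact smoothing/normalization details are assumptions based on the formula and hyperparameters stated in this record.

```python
import numpy as np

def count_init_bigram(tokens: np.ndarray, vocab_size: int = 1024,
                      alpha: float = 0.25, clip: float = 4.0) -> np.ndarray:
    """Build B[a, b] = log p(b|a) - log p(b) from a token stream,
    with additive smoothing and clipping as described above."""
    counts = np.zeros((vocab_size, vocab_size), dtype=np.float64)
    np.add.at(counts, (tokens[:-1], tokens[1:]), 1.0)   # exact bigram counts
    uni = np.bincount(tokens, minlength=vocab_size).astype(np.float64)
    # Additive (add-alpha) smoothing before taking logs.
    p_cond = (counts + alpha) / (counts.sum(axis=1, keepdims=True) + alpha * vocab_size)
    p_marg = (uni + alpha) / (uni.sum() + alpha * vocab_size)
    B = np.log(p_cond) - np.log(p_marg)[None, :]
    return np.clip(B, -clip, clip)
```

With alpha > 0 every entry is finite, so the clip to [-4, 4] bounds the bias even for unseen bigrams.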
### 2. Int4 Nibble Packing (Novel)
Custom `pack_i4` / `unpack_i4` functions pack signed int4 values in [-8, 7] into uint8 bytes (two values per byte). Applied to the bigram logit table, halving its storage cost.
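A minimal sketch of the nibble packing, assuming a low-nibble-first layout and two's-complement int4 encoding (the repo's actual `pack_i4` / `unpack_i4` may order nibbles differently):

```python
import numpy as np

def pack_i4(x: np.ndarray) -> np.ndarray:
    """Pack signed int4 values in [-8, 7] into uint8, two per byte."""
    x = x.reshape(-1)
    assert x.size % 2 == 0, "need an even number of values"
    u = (x.astype(np.int16) & 0xF).astype(np.uint8)  # two's-complement nibbles
    return u[0::2] | (u[1::2] << 4)                  # low nibble first

def unpack_i4(packed: np.ndarray) -> np.ndarray:
    """Inverse of pack_i4: recover the signed int4 values."""
    out = np.empty(packed.size * 2, dtype=np.int16)
    out[0::2] = packed & 0xF
    out[1::2] = (packed >> 4) & 0xF
    return np.where(out > 7, out - 16, out).astype(np.int8)  # sign-extend
```

The roundtrip is exact for any even-length array of values in [-8, 7], and the packed buffer is exactly half the size of an int8 representation.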
### 3. XSA (Exclusive Self Attention) - Last 4 Layers
Removes the self-value component from attention output (arxiv:2603.09078).
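One simple reading of "removing the self-value component" is masking the diagonal of the causal attention matrix, so position i's output never mixes in its own value vector. The toy single-head sketch below illustrates that reading; the cited paper's exact formulation (and the repo's implementation) may differ.

```python
import numpy as np

def xsa_attention(q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Toy causal attention with the self term masked out, so position i's
    output excludes v[i]. q, k, v: (seq_len, head_dim)."""
    T, d = q.shape
    scores = q @ k.T / np.sqrt(d)
    mask = np.triu(np.ones((T, T), dtype=bool), 1)  # mask future positions
    np.fill_diagonal(mask, True)                    # ...and self (the XSA part)
    mask[0, 0] = False  # token 0 has no other key; keep self to avoid a NaN row
    scores = np.where(mask, -np.inf, scores)
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return w @ v
```

A quick sanity check of the exclusion: position 1 can only attend to position 0, so its output is exactly v[0] regardless of v[1].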
### 4. Partial RoPE (16 of 64 dims)
Apply rotary position embeddings to only 25% of head dimensions. The remaining 75% attend without positional bias, acting as position-independent feature detectors.
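A sketch of the partial rotation for one head, assuming the rotated dims are the first 16 and a standard half-split pairing; the repo's frequency schedule and pairing layout may differ.

```python
import numpy as np

def partial_rope(x: np.ndarray, rot_dims: int = 16,
                 base: float = 10000.0) -> np.ndarray:
    """Rotate only the first `rot_dims` of each head dimension; the
    remaining dims pass through untouched. x: (seq_len, head_dim)."""
    T, D = x.shape
    half = rot_dims // 2
    inv_freq = base ** (-np.arange(half) / half)
    ang = np.outer(np.arange(T), inv_freq)       # (T, half) rotation angles
    cos, sin = np.cos(ang), np.sin(ang)
    x1, x2 = x[:, :half], x[:, half:rot_dims]    # paired halves to rotate
    rotated = np.concatenate([x1 * cos - x2 * sin,
                              x1 * sin + x2 * cos], axis=-1)
    return np.concatenate([rotated, x[:, rot_dims:]], axis=-1)
```

The last 48 of 64 dims come back bit-identical, and position 0 is unrotated (angle zero), which matches the intent of leaving most dims position-independent.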
### 5. LN Scale
Block outputs are scaled by `1/sqrt(layer_idx + 1)`, damping deeper layers' contributions and stabilizing training. Zero parameters.
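The update amounts to a one-line change in the residual stream. This sketch assumes the scale multiplies the block's output on the residual branch; the record's code may place it slightly differently.

```python
import math

def scaled_residual(x: float, block_out: float, layer_idx: int) -> float:
    """Residual update with the parameter-free 1/sqrt(layer_idx + 1) scale:
    layer 0 contributes at full strength, layer 9 at ~0.32x."""
    return x + block_out / math.sqrt(layer_idx + 1)
```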
### 6. Higher Learning Rates
- matrix_lr: 0.025 (up from 0.02)
- scalar_lr: 0.025
- tied_embed_lr: 0.035
## Architecture
- 10 layers, 512 dim, 8 heads, 4 KV heads (GQA)
- MLP 3x expansion (hidden=1536), relu squared activation
**Copilot AI** (Mar 23, 2026) suggested changing the MLP line:

> `- MLP 3x expansion (hidden=1536), relu squared activation`
> → `- MLP 3x expansion (hidden=1536), leaky relu squared activation (negative_slope=0.5)`

README lists "warmdown=2800 iters", but `train_gpt.py` defaults `WARMDOWN_ITERS` to 3000. Please reconcile the README with the actual default hyperparameters used for this record (or explicitly document the env var override used for the run).
{
  "name": "10L CountInitBigram + XSA + PartialRoPE + LN Scale",
  "val_loss": 1.94544622,
  "val_bpb": 1.15220588,
  "bytes_total": 15384232,
  "bytes_model_int6_zstd": 15322709,
  "bytes_code": 61523,
  "blurb": "10 layers with count-initialized exact bigram logit head (corpus log-prob residuals, int4 nibble-packed), XSA on last 4 layers, Partial RoPE (16/64 dims), LN Scale (1/sqrt(layer+1)), higher LR (0.025). SWA over 22 checkpoints, sliding window eval stride=64.",
  "author": "Sri Harsha Gouru",
  "github_id": "harsha-gouru",
  "date": "2026-03-22"
}
├── 🔨 Created mount train_gpt_submission.py
└── 🔨 Created function train_8gpu.
Downloading: variant=sp1024, train_shards=80...
val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/data/data/tokenizers/fineweb_1024_bpe.model
train_loader:dataset:fineweb10B_sp1024 train_shards:80
val_loader:shards pattern=/data/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
bigram_init:count_based time:0.22s
model_params:25779282
world_size:8 grad_accum_steps:1
attention_mode:gqa num_heads:8 num_kv_heads:4
tie_embeddings:True embed_lr:0.035 matrix_lr:0.025 scalar_lr:0.025
train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000
seed:42
warmup_step:1/20
warmup_step:2/20
warmup_step:3/20
warmup_step:4/20
warmup_step:5/20
warmup_step:6/20
warmup_step:7/20
warmup_step:8/20
warmup_step:9/20
warmup_step:10/20
warmup_step:11/20
warmup_step:12/20
warmup_step:13/20
warmup_step:14/20
warmup_step:15/20
warmup_step:16/20
warmup_step:17/20
warmup_step:18/20
warmup_step:19/20
warmup_step:20/20
step:0/20000 val_loss:7.3822 val_bpb:4.3721 train_time:0ms step_avg:0.01ms
step:1/20000 train_loss:7.3777 train_time:223ms step_avg:223.20ms
step:2/20000 train_loss:5.1389 train_time:303ms step_avg:151.25ms
step:3/20000 train_loss:6.8711 train_time:399ms step_avg:132.99ms
step:4/20000 train_loss:7.9386 train_time:498ms step_avg:124.43ms
step:5/20000 train_loss:7.0619 train_time:593ms step_avg:118.64ms
step:6/20000 train_loss:6.4111 train_time:689ms step_avg:114.78ms
step:7/20000 train_loss:5.8449 train_time:785ms step_avg:112.11ms
step:8/20000 train_loss:5.4955 train_time:879ms step_avg:109.93ms
step:9/20000 train_loss:5.4268 train_time:974ms step_avg:108.26ms
step:10/20000 train_loss:5.1444 train_time:1069ms step_avg:106.90ms
step:100/20000 train_loss:3.1036 train_time:9588ms step_avg:95.88ms
step:200/20000 train_loss:2.3712 train_time:19185ms step_avg:95.93ms
step:300/20000 train_loss:2.5539 train_time:28785ms step_avg:95.95ms
step:400/20000 train_loss:2.4180 train_time:38370ms step_avg:95.92ms
step:500/20000 train_loss:2.4018 train_time:47882ms step_avg:95.76ms
step:500/20000 val_loss:2.3625 val_bpb:1.3992 train_time:47906ms step_avg:95.81ms
step:600/20000 train_loss:2.3373 train_time:57505ms step_avg:95.84ms
step:700/20000 train_loss:2.3546 train_time:67225ms step_avg:96.04ms
step:800/20000 train_loss:2.2471 train_time:76829ms step_avg:96.04ms
step:900/20000 train_loss:2.1361 train_time:86414ms step_avg:96.02ms
step:1000/20000 train_loss:2.2881 train_time:95886ms step_avg:95.89ms
step:1000/20000 val_loss:2.2347 val_bpb:1.3235 train_time:95910ms step_avg:95.91ms
step:1100/20000 train_loss:2.3264 train_time:105475ms step_avg:95.89ms
step:1200/20000 train_loss:2.3602 train_time:115065ms step_avg:95.89ms
step:1300/20000 train_loss:2.1075 train_time:124651ms step_avg:95.89ms
step:1400/20000 train_loss:2.1932 train_time:134238ms step_avg:95.88ms
step:1500/20000 train_loss:2.2303 train_time:143696ms step_avg:95.80ms
step:1500/20000 val_loss:2.1931 val_bpb:1.2989 train_time:143719ms step_avg:95.81ms
step:1600/20000 train_loss:2.0843 train_time:153282ms step_avg:95.80ms
step:1700/20000 train_loss:2.1487 train_time:162862ms step_avg:95.80ms
step:1800/20000 train_loss:2.1646 train_time:172451ms step_avg:95.81ms
step:1900/20000 train_loss:2.1335 train_time:181904ms step_avg:95.74ms
step:2000/20000 train_loss:2.0755 train_time:191501ms step_avg:95.75ms
step:2000/20000 val_loss:2.1413 val_bpb:1.2682 train_time:191524ms step_avg:95.76ms
step:2100/20000 train_loss:2.0542 train_time:201099ms step_avg:95.76ms
step:2200/20000 train_loss:2.1418 train_time:210687ms step_avg:95.77ms
step:2300/20000 train_loss:2.1218 train_time:220300ms step_avg:95.78ms
step:2400/20000 train_loss:2.0761 train_time:229761ms step_avg:95.73ms
step:2500/20000 train_loss:2.1798 train_time:239357ms step_avg:95.74ms
step:2500/20000 val_loss:2.1160 val_bpb:1.2532 train_time:239380ms step_avg:95.75ms
step:2600/20000 train_loss:2.1206 train_time:248938ms step_avg:95.75ms
step:2700/20000 train_loss:2.1137 train_time:258488ms step_avg:95.74ms
step:2800/20000 train_loss:2.1632 train_time:268049ms step_avg:95.73ms
step:2900/20000 train_loss:2.0369 train_time:277499ms step_avg:95.69ms
step:3000/20000 train_loss:2.1721 train_time:287094ms step_avg:95.70ms
step:3000/20000 val_loss:2.1034 val_bpb:1.2457 train_time:287118ms step_avg:95.71ms
step:3100/20000 train_loss:2.0467 train_time:296656ms step_avg:95.70ms
step:3200/20000 train_loss:2.1853 train_time:306283ms step_avg:95.71ms
step:3300/20000 train_loss:2.0843 train_time:315738ms step_avg:95.68ms
step:3400/20000 train_loss:2.0352 train_time:325337ms step_avg:95.69ms
step:3500/20000 train_loss:2.1954 train_time:334966ms step_avg:95.70ms
step:3500/20000 val_loss:2.0956 val_bpb:1.2411 train_time:334990ms step_avg:95.71ms
step:3600/20000 train_loss:2.1086 train_time:344546ms step_avg:95.71ms
step:3700/20000 train_loss:2.1098 train_time:354112ms step_avg:95.71ms
step:3800/20000 train_loss:2.0836 train_time:363568ms step_avg:95.68ms
step:3900/20000 train_loss:2.0833 train_time:373128ms step_avg:95.67ms
step:4000/20000 train_loss:1.9800 train_time:382712ms step_avg:95.68ms
step:4000/20000 val_loss:2.0729 val_bpb:1.2277 train_time:382735ms step_avg:95.68ms
step:4100/20000 train_loss:2.0146 train_time:392308ms step_avg:95.68ms
step:4200/20000 train_loss:2.1597 train_time:401902ms step_avg:95.69ms
step:4300/20000 train_loss:2.0659 train_time:411352ms step_avg:95.66ms
step:4400/20000 train_loss:2.0390 train_time:420921ms step_avg:95.66ms
step:4500/20000 train_loss:2.1277 train_time:430528ms step_avg:95.67ms
step:4500/20000 val_loss:2.0468 val_bpb:1.2122 train_time:430552ms step_avg:95.68ms
step:4600/20000 train_loss:1.8398 train_time:440101ms step_avg:95.67ms
step:4700/20000 train_loss:2.2341 train_time:449553ms step_avg:95.65ms
step:4800/20000 train_loss:2.4262 train_time:459158ms step_avg:95.66ms
step:4900/20000 train_loss:2.0473 train_time:468724ms step_avg:95.66ms
step:5000/20000 train_loss:2.1047 train_time:478331ms step_avg:95.67ms
step:5000/20000 val_loss:2.0222 val_bpb:1.1976 train_time:478355ms step_avg:95.67ms
step:5100/20000 train_loss:2.1237 train_time:487891ms step_avg:95.66ms
swa:start step:5200
step:5200/20000 train_loss:2.0397 train_time:497349ms step_avg:95.64ms
step:5300/20000 train_loss:2.0038 train_time:507084ms step_avg:95.68ms
step:5400/20000 train_loss:2.0424 train_time:516727ms step_avg:95.69ms
step:5500/20000 train_loss:2.0107 train_time:526366ms step_avg:95.70ms
step:5500/20000 val_loss:1.9943 val_bpb:1.1811 train_time:526416ms step_avg:95.71ms
step:5600/20000 train_loss:1.9457 train_time:535996ms step_avg:95.71ms
step:5700/20000 train_loss:2.0044 train_time:545480ms step_avg:95.70ms
step:5800/20000 train_loss:1.9851 train_time:555138ms step_avg:95.71ms
step:5900/20000 train_loss:1.8864 train_time:564743ms step_avg:95.72ms
step:6000/20000 train_loss:1.9267 train_time:574385ms step_avg:95.73ms
step:6000/20000 val_loss:1.9648 val_bpb:1.1636 train_time:574435ms step_avg:95.74ms
step:6100/20000 train_loss:1.9044 train_time:583877ms step_avg:95.72ms
step:6200/20000 train_loss:1.9320 train_time:593509ms step_avg:95.73ms
step:6267/20000 val_loss:1.9523 val_bpb:1.1563 train_time:600046ms step_avg:95.75ms
stopping_early: wallclock_cap train_time:600046ms step:6267/20000
peak memory allocated: 19609 MiB reserved: 19878 MiB
swa:applying averaged 22 checkpoints
Serialized model: 98962351 bytes
Serialized model int6+zstd: 15322709 bytes
final_eval_mode:sliding_window stride:64 batch_seqs:32
final_int8_zlib_roundtrip val_loss:1.9454 val_bpb:1.1522 eval_time:183742ms
final_int8_zlib_roundtrip_exact val_loss:1.94544622 val_bpb:1.15220588
README claims "All parameters are set as defaults … No env vars needed.", but `train_gpt.py` defaults `BIGRAM_LOGIT_HEAD` to 0, so running the command as written won't enable the count-initialized exact bigram logit head described in this record. Either set the default to enabled for this record script, or update the run command/docs to include `BIGRAM_LOGIT_HEAD=1` (and any other required env vars) so results are reproducible.