
Record: 10L CountInitBigram + XSA + PartialRoPE (val_bpb=1.1522)#482

Closed
harsha-gouru wants to merge 1 commit into openai:main from harsha-gouru:submission/2026-03-22_CountInitBigram_XSA

Conversation

@harsha-gouru

Summary

  • val_bpb: 1.1522 (sliding window stride=64, post int5/int6+zstd quantization roundtrip)
  • 10 layers, 512 dim, 8 heads / 4 KV heads, tied embeddings
  • Artifact: 15,384,232 bytes (15.38 MB)

Novel Contributions

Count-Initialized Exact Bigram Logit Head

A 1024x1024 lookup table initialized from corpus bigram transition probabilities (B[a,b] = log p(b|a) - log p(b)) before training begins. Provides a strong Markov prior from step 0 — the neural network only needs to refine it. Applied BEFORE logit softcap.
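The count-based initialization described above can be sketched as follows. This is a hypothetical stand-in, not the record script's code: the helper name `count_init_bigram`, the add-one smoothing, and the dtype choices are illustrative assumptions.

```python
import numpy as np

def count_init_bigram(tokens: np.ndarray, vocab_size: int = 1024,
                      smoothing: float = 1.0) -> np.ndarray:
    """Build B[a, b] = log p(b|a) - log p(b) from a flat token stream."""
    counts = np.full((vocab_size, vocab_size), smoothing, dtype=np.float64)
    # Accumulate bigram transition counts (handles repeated index pairs correctly).
    np.add.at(counts, (tokens[:-1], tokens[1:]), 1.0)
    # Conditional log-probability log p(b|a), row-normalized.
    log_cond = np.log(counts) - np.log(counts.sum(axis=1, keepdims=True))
    # Unigram log-probability log p(b) from column totals.
    uni = counts.sum(axis=0)
    log_uni = np.log(uni) - np.log(uni.sum())
    return (log_cond - log_uni).astype(np.float32)
```

Subtracting log p(b) centers the table so it encodes only the transition signal; the model's own logits still carry the unigram statistics.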

Int4 Nibble Packing

Custom pack_i4/unpack_i4 functions pack signed int4 values into uint8 bytes (two values per byte). Applied to the bigram logit table, halving its storage from ~1MB to ~524KB.
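A minimal sketch of how such nibble-packing helpers might look (numpy stand-in for illustration; the actual `pack_i4`/`unpack_i4` in `train_gpt.py` operate on exported tensors and may differ in layout):

```python
import numpy as np

def pack_i4(x: np.ndarray) -> np.ndarray:
    """Pack signed int4 values in [-8, 7] into uint8, two values per byte."""
    assert x.size % 2 == 0
    # Take the low nibble of each value's two's-complement representation.
    u = (x.astype(np.int16) & 0xF).astype(np.uint8).ravel()
    # Even-index values go in the low nibble, odd-index in the high nibble.
    return u[0::2] | (u[1::2] << 4)

def unpack_i4(packed: np.ndarray) -> np.ndarray:
    """Inverse of pack_i4: recover the signed int4 values."""
    lo = (packed & 0xF).astype(np.int8)
    hi = ((packed >> 4) & 0xF).astype(np.int8)
    out = np.stack([lo, hi], axis=-1).ravel()
    # Sign-extend: nibbles 8..15 map back to -8..-1.
    return np.where(out >= 8, out - 16, out).astype(np.int8)
```

The round trip is exact for any int4-range input, which is what makes the scheme safe for the logit table: only the initial quantization to int4 loses precision, not the packing.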

Adopted Techniques

  • XSA on last 4 layers (arxiv:2603.09078)
  • Partial RoPE 16/64 dims
  • LN Scale 1/sqrt(layer+1)
  • Higher LR (0.025)
  • SmearGate, OrthoInit, U-Net skips, SWA, int5/int6 + zstd-22
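As an illustration of the partial-RoPE item above (rotating only 16 of 64 head dims), here is a hedged numpy sketch; the record script's rotation convention (paired halves vs. interleaved pairs) is an assumption here, as is the base frequency:

```python
import numpy as np

def partial_rope(q: np.ndarray, pos: np.ndarray, rope_dims: int = 16,
                 base: float = 10000.0) -> np.ndarray:
    """Rotate only the first `rope_dims` of each head dim; pass the rest through.
    q: (seq, head_dim), pos: (seq,) integer positions."""
    rot, keep = q[..., :rope_dims], q[..., rope_dims:]
    half = rope_dims // 2
    inv_freq = base ** (-np.arange(half) / half)
    angles = pos[:, None] * inv_freq[None, :]          # (seq, half)
    cos, sin = np.cos(angles), np.sin(angles)
    # Paired-halves convention: dim i rotates with dim i + half.
    x1, x2 = rot[..., :half], rot[..., half:]
    rot_out = np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
    return np.concatenate([rot_out, keep], axis=-1)
```

The untouched 48 dims act as position-free channels, which is the usual motivation for partial rotation: the model keeps some capacity for purely content-based matching.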

Results

Steps: 6267 at 95.75 ms/step (8xH100 SXM, 600s wallclock)
Pre-SWA val_bpb: 1.1563
Post-SWA+quant val_bpb: 1.1522
Artifact: 15.38 MB (0.62 MB headroom)

Test plan

  • Verified artifact under 16MB limit
  • Post-quant roundtrip bpb verified
  • 8xH100 SXM, 600s wallclock

Built on baseline by @thwu1 (PR #180).

Copilot AI review requested due to automatic review settings March 23, 2026 00:28
Contributor

Copilot AI left a comment


Pull request overview

Adds a new /records submission for the 10min/16MB track with a count-initialized exact bigram logit head (int4 nibble-packed) and several architecture/training tweaks (XSA, partial RoPE, LN scaling), alongside the run log and submission metadata.

Changes:

  • Introduces a new record training script implementing CountInitBigram (exact bigram logit bias) and int4 nibble packing in the export path.
  • Adds documentation and metadata (README + submission.json) describing the run and results.
  • Includes the training log capturing the reported metrics and artifact size.

Reviewed changes

Copilot reviewed 3 out of 4 changed files in this pull request and generated 3 comments.

File Description
records/track_10min_16mb/2026-03-22_10L_CountInitBigram_XSA_PartialRoPE/train_gpt.py New record training script with CountInitBigram + int4 packing + XSA/PartialRoPE/LN scaling and quantized export.
records/track_10min_16mb/2026-03-22_10L_CountInitBigram_XSA_PartialRoPE/train.log Training/eval output for the record run.
records/track_10min_16mb/2026-03-22_10L_CountInitBigram_XSA_PartialRoPE/submission.json Submission metadata (metrics, sizes, blurb, author).
records/track_10min_16mb/2026-03-22_10L_CountInitBigram_XSA_PartialRoPE/README.md Record description and reproduction command/technique summary.


Comment on lines +99 to +105
```python
# Exact bigram logit head (can be enabled alongside BigramHash)
bigram_logit_head = bool(int(os.environ.get("BIGRAM_LOGIT_HEAD", "0")))

# Architectural improvements from top PRs
xsa_last_n = int(os.environ.get("XSA_LAST_N", 4))      # XSA on last N layers
rope_dims = int(os.environ.get("ROPE_DIMS", 16))       # Partial RoPE: rotate this many dims
ln_scale = bool(int(os.environ.get("LN_SCALE", "1")))  # RMSNorm scale by 1/sqrt(layer+1)
```

Copilot AI Mar 23, 2026


Hyperparameters.bigram_logit_head defaults to 0, but this record/README/submission claims the count-initialized BigramLogitHead is part of the default run. With the current default, the bigram head (and count-based init + int4 packing) won’t be active unless BIGRAM_LOGIT_HEAD=1 is set, which makes the provided run command non-reproducible as written. Either flip the default to 1 for this record script, or update the README run command (and/or submission blurb) to explicitly set BIGRAM_LOGIT_HEAD=1.
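For illustration, the default-off behavior the comment describes follows from the quoted parsing line. The helper name below is hypothetical; it takes a plain dict in place of `os.environ` so the behavior is easy to test:

```python
def bigram_head_enabled(env: dict) -> bool:
    """Mirrors bool(int(os.environ.get("BIGRAM_LOGIT_HEAD", "0"))) from the
    quoted hyperparameter block, with the environment passed in explicitly."""
    return bool(int(env.get("BIGRAM_LOGIT_HEAD", "0")))
```

With an empty environment this returns False, so the documented `torchrun` command would need `BIGRAM_LOGIT_HEAD=1` set for the head to be active, as the comment notes.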

Comment on lines +8 to +12
```
torchrun --standalone --nproc_per_node=8 train_gpt.py
```

All parameters are set as defaults in `train_gpt.py`. No env vars needed.


Copilot AI Mar 23, 2026


The README says “All parameters are set as defaults… No env vars needed”, but train_gpt.py has BIGRAM_LOGIT_HEAD defaulting to 0. As a result, running the documented command won’t enable the count-initialized exact bigram logit head unless an env var is provided (or the script default is changed). Please update the README command/notes to reflect the needed env var(s), or change the script defaults to match the documented configuration.

Comment on lines +1407 to +1411
```python
log0(
    f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} "
    f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms"
)
log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}")
```

Copilot AI Mar 23, 2026


A few log lines hardcode "int8+zlib" / "final_int8_zlib_roundtrip" even though this script may actually use zstd (and uses mixed int5/int6 + optional int4 packed tensors). This makes the logs misleading when _COMPRESSOR == 'zstd'. Consider incorporating {_COMPRESSOR} (and the actual quant scheme) into these log message strings so the recorded artifact details match what was produced.

Suggested change:

```python
# before
log0(
    f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} "
    f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms"
)
log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}")

# after
quant_scheme = "mixed_int6"  # matches dequantize_mixed_int6 / mixed int5/6 (+ optional int4) scheme
log0(
    f"final_{quant_scheme}_{_COMPRESSOR}_roundtrip "
    f"val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} "
    f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms"
)
log0(
    f"final_{quant_scheme}_{_COMPRESSOR}_roundtrip_exact "
    f"val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}"
)
```

@harsha-gouru harsha-gouru marked this pull request as draft March 23, 2026 00:35
@harsha-gouru harsha-gouru deleted the submission/2026-03-22_CountInitBigram_XSA branch March 23, 2026 00:56