
Non-record: Byte-level transformer + JEPA auxiliary loss (val_bpb: 1.1903) #832

Open
jfprincz wants to merge 1 commit into openai:main from jfprincz:submission/byte-jepa-compression-1.1903

Conversation

@jfprincz
Contributor

val_bpb: 1.1903 (sliding window, stride=512) | 14.4 MB | 8xH100 SXM, 600s

Byte-level autoregressive transformer (vocab 260, no tokenizer) with a lightweight JEPA auxiliary loss contributing ~0.1% of peak gradient signal. Beats the sp1024 baseline (1.2244) by 0.034 BPB.

Ablation: JEPA contribution

| Metric | Without JEPA | With JEPA | Delta |
|---|---|---|---|
| Int6 sliding s512 | 1.2006 | 1.1905 | -0.0101 |
| Step time | 60ms | 63ms | +3ms |
| Params | 24.2M | 24.6M | +459K |

JEPA yields a 0.0101 BPB improvement for roughly 5% step-time overhead. The improvement is consistent across seeds and evaluation methods (pre-quant, post-quant, sliding-window).

Architecture

13-layer byte-level autoregressive transformer (vocab=260, no BPE/SentencePiece). The primary objective is standard next-byte CE loss. A lightweight JEPA module predicts chunk-level latent representations as an auxiliary signal (λ_max=0.001), improving val_bpb by ~0.01 over the pure-AR baseline. Chunk prediction inspired by LeWM.

| Component | Detail |
|---|---|
| Backbone | 13L, dim=512, 8H/4KV GQA, MLP 2x, LeakyReLU(0.5)², U-Net skips |
| JEPA projector | Linear(512, 256) → RMSNorm → SiLU → Linear(256, 256) |
| JEPA predictor | 2-layer MLP, 256d, causal shift with learned start token |
| JEPA injection | Linear(256, 512), zero-init, adds predicted latents to residual stream |
| SIGReg | Epps-Pulley, 256 projections, 17 knots; prevents latent collapse |
| Training | Phased: 30% pure AR, 50% AR+JEPA ramp, 20% pure AR |
| Loss | CE + λ·(MSE_pred + 0.02·SIGReg), λ ramps 0→0.001 |
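The projector, predictor, and injection rows above can be sketched end to end. This is a hedged NumPy mock-up, not the PR's code: shapes and layer order follow the table, but the weight initialization, the zeroed stand-in for the learned start token, and the use of the projector output directly as the target (rather than an EMA/stop-gradient branch) are placeholder assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
DIM, LATENT = 512, 256  # backbone width and JEPA latent width from the table

def rms_norm(x, eps=1e-6):
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

def silu(x):
    return x / (1.0 + np.exp(-x))

# Projector: Linear(512,256) -> RMSNorm -> SiLU -> Linear(256,256)
W1 = rng.normal(0, 0.02, (DIM, LATENT))
W2 = rng.normal(0, 0.02, (LATENT, LATENT))

def project(h):  # h: (chunks, 512) pooled chunk representations
    return silu(rms_norm(h @ W1)) @ W2

# Predictor: 2-layer MLP over latents, shifted right so chunk i is
# predicted only from chunks < i ("causal shift with learned start token";
# the start token is a zero placeholder here).
start = np.zeros((1, LATENT))
P1 = rng.normal(0, 0.02, (LATENT, LATENT))
P2 = rng.normal(0, 0.02, (LATENT, LATENT))

def predict(z):  # z: (chunks, 256) target latents
    z_in = np.concatenate([start, z[:-1]], axis=0)  # causal shift
    return silu(z_in @ P1) @ P2

# Injection: Linear(256,512), zero-init, so at step 0 the residual
# stream is untouched and the auxiliary path ramps in during training.
W_inj = np.zeros((LATENT, DIM))

h = rng.normal(size=(16, DIM))        # 16 chunk states
z_tgt = project(h)
z_pred = predict(z_tgt)
residual_update = z_pred @ W_inj      # all zeros at init, by construction
mse_pred = np.mean((z_pred - z_tgt) ** 2)
```

The zero-init injection is the standard trick for adding an auxiliary path without perturbing a pretrained or co-trained backbone at step 0.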

Carried from our sp1024 stack: Muon+WD=0.04, EMA 0.997, XSA last 4 layers, Partial RoPE 16 dims, LN Scale, SmearGate, BigramHash(4096,32), OrthoInit+muP, int6+zstd-22, FA3.
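The phased training and loss rows above admit a small schedule sketch. The linear ramp shape and exact phase boundaries are assumptions; the PR only states the 30%/50%/20% phase split and λ: 0→0.001. The function names here (`jepa_lambda`, `total_loss`) are hypothetical, not from `train_gpt.py`.

```python
def jepa_lambda(step, total_steps=9000, lambda_max=0.001):
    """Phased JEPA weight: 30% pure AR, 50% AR+JEPA ramp, 20% pure AR.

    Assumption: lambda ramps linearly 0 -> lambda_max during the middle
    phase and returns to 0 for the final pure-AR phase.
    """
    frac = step / total_steps
    if frac < 0.30 or frac >= 0.80:
        return 0.0
    return lambda_max * (frac - 0.30) / 0.50

def total_loss(ce, mse_pred, sigreg, step, total_steps=9000):
    """Loss row from the table: CE + lambda * (MSE_pred + 0.02 * SIGReg)."""
    lam = jepa_lambda(step, total_steps)
    return ce + lam * (mse_pred + 0.02 * sigreg)
```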

Results

| Metric | Value |
|---|---|
| Pre-quant val_bpb | 1.2293 |
| Int6 roundtrip val_bpb | 1.2184 |
| Int6 sliding val_bpb (s512) | 1.1905 |
| Steps completed | 9,000 |
| Step time | 63ms |
| Model params | 24,625,001 |
| Artifact size | 14,182,907 bytes |
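The "Int6 roundtrip" row measures val_bpb after quantizing weights to 6 bits and dequantizing. A minimal sketch of what such a roundtrip looks like, assuming symmetric per-tensor scaling; the submission's actual scheme (per-channel scales, zstd-22 packing of the artifact) is not shown in this PR:

```python
import numpy as np

def int6_roundtrip(w):
    """Hypothetical symmetric int6 quantize/dequantize roundtrip.

    Signed 6-bit values span [-31, 31] here (one code left unused for
    symmetry); the PR does not specify the exact codebook.
    """
    scale = np.abs(w).max() / 31.0
    q = np.clip(np.round(w / scale), -31, 31).astype(np.int8)
    return q * scale, q

w = np.linspace(-1.0, 1.0, 101)
w_hat, q = int6_roundtrip(w)
max_err = np.abs(w - w_hat).max()   # bounded by scale / 2
```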

Reproducibility (3 seeds)

| Seed | Steps | Sliding s512 | Artifact |
|---|---|---|---|
| 2025 | 9,000 | 1.1903 | 14,369,791 |
| 42 | 9,000 | 1.1905 | 14,182,907 |
| 7 | 9,000 | 1.1915 | 14,445,175 |

Mean: 1.1908 | Range: 0.0012 | Submitted: seed 2025

Run command

```shell
NUM_LAYERS=13 VOCAB_SIZE=260 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4 MLP_MULT=2.0 \
TRAIN_SEQ_LEN=4096 TRAIN_BATCH_TOKENS=393216 BIGRAM_VOCAB_SIZE=4096 BIGRAM_DIM=32 \
XSA_LAST_N=4 EMA_ENABLED=1 EMA_DECAY=0.997 SWA_ENABLED=0 \
ROPE_DIMS=16 LN_SCALE=1 MUON_WD=0.04 ADAM_WD=0.04 \
MATRIX_LR=0.025 SCALAR_LR=0.025 TIED_EMBED_LR=0.035 \
MUON_MOMENTUM=0.99 MUON_MOMENTUM_WARMUP_START=0.92 MUON_MOMENTUM_WARMUP_STEPS=1500 \
WARMDOWN_ITERS=3000 ITERATIONS=9000 MAX_WALLCLOCK_SECONDS=600 \
EVAL_STRIDE=512 JEPA_CHUNK_SIZE=8 JEPA_LATENT_DIM=256 JEPA_PROJ_HIDDEN=256 \
JEPA_LAMBDA_MAX=0.001 JEPA_SIGREG_WEIGHT=0.02 JEPA_LR=0.001 \
torchrun --standalone --nproc_per_node=8 train_gpt.py
```
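The headline number is a sliding-window eval with EVAL_STRIDE=512 over TRAIN_SEQ_LEN=4096 windows. A window-indexing sketch, assuming the common convention that only the final `stride` positions of each window (which have near-full context) are scored, except the first window, which scores everything; tail handling and the exact convention in `train_gpt.py` are assumptions:

```python
def sliding_windows(n_tokens, seq_len=4096, stride=512):
    """Return (start, end, score_from) triples for sliding-window eval.

    Each window covers [start, end); loss is accumulated only over
    positions [score_from, end), so every token is scored exactly once
    with maximal available context.
    """
    out = []
    start = 0
    while start + seq_len <= n_tokens:  # hypothetical tail handling
        end = start + seq_len
        score_from = start if start == 0 else end - stride
        out.append((start, end, score_from))
        start += stride
    return out

wins = sliding_windows(8192)
```

With 8192 tokens this produces 9 windows whose scored spans tile the sequence exactly once.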

Data

Uses fineweb10B_byte260: raw UTF-8 bytes mapped to IDs with byte_offset=4 (IDs 4-259 = byte values 0-255). Converted from the sp1024 shards via a lookup-table decode; no SentencePiece dependency at runtime. BPB = loss / ln(2), with no tokens-per-byte correction needed.
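The byte mapping and BPB conversion described above fit in a few lines. A minimal sketch; the purpose of reserved IDs 0-3 is not specified in the PR, and the function names here are illustrative:

```python
import math

BYTE_OFFSET = 4  # IDs 0-3 reserved (purpose unspecified); 4-259 = bytes 0-255

def encode_bytes(text: str) -> list[int]:
    """Map raw UTF-8 bytes to vocab-260 IDs with byte_offset=4."""
    return [b + BYTE_OFFSET for b in text.encode("utf-8")]

def decode_bytes(ids) -> str:
    return bytes(i - BYTE_OFFSET for i in ids).decode("utf-8")

def bpb(nats_per_byte: float) -> float:
    """Byte-level CE loss is already per byte, so bits-per-byte is just
    loss / ln(2); no tokenizer correction factor is needed."""
    return nats_per_byte / math.log(2)

ids = encode_bytes("hi")   # 'h' = 104, 'i' = 105  ->  [108, 109]
```

With a subword tokenizer, by contrast, the per-token loss must be rescaled by tokens-per-byte before it is comparable; the byte-level setup sidesteps that entirely.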

@MatoTeziTanka

Community Review — Non-record: Byte-level transformer + JEPA auxiliary loss (val_bpb: 1.1903)

BPB: 1.1903 | Compliance: LOOKS CLEAN (pure-neural submission, no TTT/SLOT/n-gram-cache)

What I found in the code (head SHA 263276327b05, file records/track_10min_16mb/2026-03-26_ByteJEPA_Compression_1.1903/train_gpt.py):

Static code review found no TTT adaptation function, no SLOT optimization loop, no n-gram-cache class, and no pre-quant val-token fine-tune. The eval path uses the standard sliding-window pattern (EVAL_STRIDE=512 per the run command). The submission is a pure-neural architecture iteration on the standard SP1024/SP4096/SP8192 baseline.

CPU smoke test (CT2038 proteus-engine, 2026-04-11): import OK in 0.03s, dim=512, layers=13, vocab=260, code=71203 B, SMOKE_TEST_PASS

Verdict: LOOKS CLEAN.

Recommendation to @cocohearts @valerio-oai @0hq @yuzhougu-oai @notapplica: MERGE pending the usual record-track checks (3-seed validation, under-16MB artifact cap, ≤600s train + ≤600s eval on 8×H100 SXM). No compliance flags from the classification pass — this looks like a clean pure-neural iteration on the standard baseline.

Auto-classification caveat: this review was drafted by the AST-based classifier. If there's a non-standard eval mechanism (logit postprocessing, hedge mixing, etc.) that I missed because it's factored into a helper file or a non-standard function name, please flag it and I'll re-run the audit manually.


Reviewed by @MatoTeziTanka (The Agora). Classification via the deterministic AST-based classify_prs.py (pattern bank derived from ~65 manually-reviewed PRs earlier in the 2026-04-11 sweep). This review was auto-drafted from a template and spot-checked before posting; if the template misread your code, please call it out so I can iterate the classifier.

