
Non-record: LLM-JEPA — Joint Embedding Prediction (val_bpb 2.2020) #1196

Open

dentity007 wants to merge 3 commits into openai:main from NathanMaine:research/llm-jepa

Conversation


@dentity007 dentity007 commented Mar 31, 2026

Non-record: LLM-JEPA - Joint Embedding Prediction for Language Modeling

val_bpb: 2.2020 | 1x RTX 5090 Ada 16GB, 180s wallclock | sp1024

Implements OpenAI's requested "JEPA" research direction.

Architecture

  • Adds a Joint Embedding Predictive Architecture (JEPA) prediction head alongside the standard AR next-token prediction loss
  • Context embeddings predict target span embeddings via a lightweight MLP predictor (see the sketch after this list)
  • Stop-gradient on targets prevents representation collapse (BYOL-style)
  • JEPA predictor is stripped from the exported model (only AR head used for eval)
  • Base config: 9 layers, d=512, 8 heads, sp1024 vocab, MLP 2x
  • Batch size halved to fit JEPA auxiliary computation in GPU memory
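
For concreteness, here is a minimal sketch of the auxiliary objective described above, assuming a BYOL-style stop-gradient on the targets and the 0.7/0.3 loss split reported under Results. Class and function names are illustrative, not the exact identifiers in train_gpt_jepa.py.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class JEPAPredictor(nn.Module):
    """Lightweight 2-layer MLP mapping context embeddings to predicted
    target-span embeddings (stripped from the exported model)."""
    def __init__(self, d_model: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        )

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        return self.net(h)

def jepa_loss(predictor: JEPAPredictor, hidden: torch.Tensor, span: int = 8) -> torch.Tensor:
    """Predict the hidden state `span` positions ahead from each context
    position; detach the target (stop-gradient) to prevent collapse."""
    context = hidden[:, :-span]           # (B, T-span, d)
    target = hidden[:, span:].detach()    # BYOL-style stop-grad targets
    return F.mse_loss(predictor(context), target)

# Combined training objective with the weights reported in Results:
# loss = 0.7 * ar_loss + 0.3 * jepa_loss(predictor, hidden)
```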

Results

| Metric | Value |
| --- | --- |
| val_bpb (final) | 2.2020 |
| Training time | 180s (1x RTX 5090) |
| JEPA loss weight | 0.3 (vs 0.7 AR) |

Key Findings

  1. JEPA as auxiliary loss does not help AR performance in this regime. The 2.2020 BPB is nearly identical to random adapters (2.2017), suggesting the JEPA objective is not contributing meaningful signal to the AR task in 180s of training.

  2. The halved batch size hurts significantly. To fit JEPA's auxiliary computation, batch size was halved. This alone likely accounts for much of the BPB regression from baseline. A fairer comparison would use gradient accumulation (a minimal loop is sketched after this list).

  3. JEPA predictor converges to non-trivial representations. The JEPA loss decreases during training, indicating the predictor learns something. But this representation knowledge does not transfer to the AR objective within the training time.

  4. The prediction head adds minimal parameters. The lightweight MLP predictor is small and stripped from the export. It does not affect artifact size.
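
On point 2, a hypothetical gradient-accumulation loop that would restore the effective batch size without the extra peak memory; accum_steps=2 undoes the halving, and the loader/optimizer names are assumed, not taken from the submitted script:

```python
# Hypothetical: two half-size micro-batches per optimizer step recover the
# gradient of the original full batch at roughly half the peak memory.
accum_steps = 2
optimizer.zero_grad(set_to_none=True)
for _ in range(accum_steps):
    x, y = loader.next_batch()          # half-size micro-batch
    loss = model(x, y) / accum_steps    # scale so summed grads match full batch
    loss.backward()
optimizer.step()
```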

Comparison to Naive Baseline

| | Naive Baseline | LLM-JEPA |
| --- | --- | --- |
| Loss | AR only | 70% AR + 30% JEPA |
| val_bpb | 1.2244 | 2.2020 |
| Batch size | Full | Half (memory constraint) |

Reproduction

```bash
pip install sentencepiece brotli
python3 data/cached_challenge_fineweb.py --variant sp1024 --train-shards 10
MAX_WALLCLOCK_SECONDS=180 python3 train_gpt_jepa.py
```

Discussion

JEPA's strength is learning abstract representations that go beyond next-token prediction. In the parameter-golf setting, however, the eval metric IS next-token prediction (BPB), so the JEPA objective's benefits are indirect at best. The approach might show stronger results with: (a) longer training where JEPA representations have time to influence the shared backbone, (b) larger batch sizes eliminating the memory constraint, or (c) a curriculum that pre-trains with JEPA then switches to pure AR.
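
As a concrete illustration of option (c), a toy loss-weight schedule might look like the following; the split fraction is an arbitrary assumption and untested here:

```python
def loss_weights(step: int, total_steps: int, jepa_frac: float = 0.3):
    """Two-phase curriculum: JEPA-only representation pre-training,
    then pure AR for the remaining steps. Returns (ar_w, jepa_w)."""
    if step < jepa_frac * total_steps:
        return 0.0, 1.0   # JEPA-only warmup
    return 1.0, 0.0       # pure AR afterwards
```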

Would welcome suggestions on JEPA variants that might transfer better to AR evaluation.

Credits

Script: train_gpt_jepa.py
Implements OpenAI's requested "JEPA" direction from the README.

@dentity007 dentity007 closed this Apr 1, 2026
@dentity007 dentity007 reopened this Apr 1, 2026
@dentity007 (Author)

Research Expansion: Ablation Results

Ran an overnight ablation study on DGX Spark GB10 to expand on this submission. 200 training steps, sp1024, no torch.compile.

Results

| Run | JEPA Weight | val_bpb | ms/step |
| --- | --- | --- | --- |
| JEPA-1 | 10% | 2.2323 | 498 |
| JEPA-2 | 30% | 2.2322 | 496 |
| JEPA-3 | 50% | 2.2322 | 497 |

Finding

Three different JEPA weights (10%, 30%, 50%) produce val_bpb values that agree to within 0.0001 (2.2323 vs 2.2322). JEPA as a concurrent auxiliary loss has effectively zero effect on AR val_bpb at 200 steps. The JEPA predictor learns something (its loss decreases during training), but that knowledge does not transfer to the AR objective within 200 steps.

Possible next steps: JEPA as a pre-training stage curriculum (JEPA-only first, then AR-only), or JEPA on final layers only (localize to representation layers rather than backbone). Both untested here.
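
For the second untested direction, one hedged way to localize JEPA to the final layer is to detach the residual stream below it, so the auxiliary gradient never reaches the lower backbone. A sketch under assumed module names (embed, blocks, lm_head), reusing the jepa_loss sketch from the architecture section; running the last block twice is purely for illustration:

```python
import torch.nn.functional as F

# Hypothetical "final layer only" JEPA: the AR branch trains the whole
# backbone, while the JEPA branch sees gradients only through the last
# block, because the residual stream is detached below it.
h = model.embed(tokens)
for block in model.blocks[:-1]:
    h = block(h)
h_ar = model.blocks[-1](h)             # full-gradient path for the AR loss
h_jepa = model.blocks[-1](h.detach())  # JEPA gradients stop below the last block
logits = model.lm_head(h_ar)
loss = 0.7 * F.cross_entropy(logits.flatten(0, 1), targets.flatten()) \
     + 0.3 * jepa_loss(predictor, h_jepa)
```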

Full raw data and logs: https://gist.github.com/dentity007/324ac35505c27acd18e7ffb468f4fa08

@MatoTeziTanka

Community Review — Non-record: LLM-JEPA — Joint Embedding Prediction (val_bpb 2.2020)

Compliance: LOOKS CLEAN — pure-neural submission, no TTT/SLOT/n-gram-cache

PR #1196 — LLM-JEPA: Joint Embedding Predictive Architecture Head
SHA: 4914d51

Files changed:
- APPROACH.md
- records/track_non_record_16mb/2026-03-31_LLMJEPA_JointEmbeddingPrediction/README.md
- records/track_non_record_16mb/2026-03-31_LLMJEPA_JointEmbeddingPrediction/submission.json
- records/track_non_record_16mb/2026-03-31_LLMJEPA_JointEmbeddingPrediction/train_gpt.py

### Check 1: ILLEGAL n-gram family bug — NOT PRESENT

No n-gram, BigramHash, or XOR-target-into-hash-key constructs anywhere in the file. Zero matches for ngram, n_gram, BigramHash, xor. CLEAR.

### Check 2: ILLEGAL Pre-Quant TTT (multi-epoch on val_tokens) — NOT PRESENT

val_tokens is only used in load_validation_tokens() (lines 217–226), eval_val() (lines 229–288), and at the two eval_val call sites in the main loop (lines 1112–1123 and 1273–1284). val_tokens is never passed to the optimizer, never backpropagated through, and never iterated in a multi-epoch loop. The eval_val function runs under torch.inference_mode() with model.eval() (lines 259–277). CLEAR.

### Check 3: LEGAL score-first TTT — NOT PRESENT

No TTT of any kind exists. There is no is_last_chunk guard and no score-first gating logic. N/A.

### Check 4: HOLD scored-region SLOT — NOT PRESENT

No scored-region slot manipulation or HOLD pattern detected. N/A.

### Check 5: Pure neural — CONFIRMED

The submission is a standard autoregressive GPT with a JEPA auxiliary training objective. The JEPA head (JEPAPredictor, lines 663–675) is a lightweight 2-layer MLP that predicts future span embeddings from context hidden states during training only. It is explicitly stripped from the exported model at serialization time (lines 1231–1235). Training uses only train-split data from DistributedTokenLoader (lines 488–505). Validation is passive inference-only. No n-grams, no TTT, no external data, no score-signal exploitation.

Conclusion: Architecturally clean pure-neural entry. JEPA auxiliary loss is a...

Verdict: LOOKS CLEAN.

Recommendation to @cocohearts @valerio-oai @0hq @yuzhougu-oai @notapplica: MERGE pending the usual record-track checks (3-seed validation, under-16MB artifact cap, ≤600s train + ≤600s eval on 8×H100 SXM). No compliance flags from the audit — this looks like a clean pure-neural submission.


Reviewed by @MatoTeziTanka (The Agora). Compliance audit via an LLM agent (Sonnet) reviewing the full train_gpt.py source, cross-checked against a deterministic AST classifier. If this review misread your code, please call it out so I can re-audit manually.

@dentity007 (Author)

Thanks for the review. Confirmed: the JEPA predictor head is stripped before serialization (lines 1231-1235 in the version you audited) so it adds no artifact bytes, and the auxiliary loss only flows through train data. The eval pass is pure AR.
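
For reference, stripping a head at export time can be as simple as filtering the state dict; the key prefix below is an assumption for illustration, not the actual identifier in the audited file:

```python
import torch

# Hypothetical export-time strip: serialize only AR weights, so the
# JEPA predictor contributes zero artifact bytes.
state = {k: v for k, v in model.state_dict().items()
         if not k.startswith("jepa_predictor.")}
torch.save(state, "model.pt")
```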

Research finding that may be useful context: my Spark ablation tested three JEPA loss weights (10%, 30%, 50%), and they produced val_bpb values agreeing to within 0.0001. The JEPA auxiliary appears to be neutral at this training duration. The predictor does learn (its loss curve decreases), but that knowledge does not transfer to the AR objective in 200 steps. This might change with a JEPA-first pre-training curriculum followed by pure AR finetuning, a direction I haven't explored yet.
