Non-record: LLM-JEPA — Joint Embedding Prediction (val_bpb 2.2020) #1196
dentity007 wants to merge 3 commits into openai:main
Conversation
Research Expansion: Ablation Results
Ran an overnight ablation study on DGX Spark GB10 to expand on this submission. 200 training steps, sp1024, no torch.compile.
Results
Finding
Three different JEPA weights (10%, 30%, 50%) produce identical BPB to 4 decimal places. JEPA as a concurrent auxiliary loss has zero effect on AR val_bpb at 200 steps. The JEPA predictor learns something (its loss decreases during training), but that knowledge does not transfer to the AR objective within 200 steps. Possible next steps: JEPA as a pre-training stage curriculum (JEPA-only first, then AR-only), or JEPA on final layers only (localize to representation layers rather than backbone). Both untested here. Full raw data and logs: https://gist.github.com/dentity007/324ac35505c27acd18e7ffb468f4fa08
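As a rough illustration of the first untested next step, a minimal sketch of the curriculum wiring is below. It assumes the training loop already computes separate ar_loss and jepa_loss tensors; the function and argument names are illustrative, not taken from train_gpt_jepa.py.

```python
# Hypothetical two-phase schedule: JEPA-only warm-up, then pure AR.
# All names here (loss_weights, jepa_warmup_steps) are illustrative.
def loss_weights(step: int, jepa_warmup_steps: int = 50, jepa_weight: float = 0.3):
    """Return (ar_weight, jepa_weight) for the current training step."""
    if step < jepa_warmup_steps:
        return 0.0, jepa_weight   # phase 1: representation learning only
    return 1.0, 0.0               # phase 2: pure autoregressive objective

# Inside the training loop (ar_loss and jepa_loss assumed to exist already):
# w_ar, w_jepa = loss_weights(step)
# loss = w_ar * ar_loss + w_jepa * jepa_loss
# loss.backward()
```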
Community Review — Non-record: LLM-JEPA — Joint Embedding Prediction (val_bpb 2.2020)
Compliance: LOOKS CLEAN — pure-neural submission, no TTT/SLOT/n-gram-cache
PR #1196 — LLM-JEPA: Joint Embedding Predictive Architecture
Head SHA: 4914d51
Files changed: -
Thanks for the review. Confirmed: the JEPA predictor head is stripped before serialization (lines 1231-1235 in the version you audited) so it adds no artifact bytes, and the auxiliary loss only flows through train data. The eval pass is pure AR. Research finding that may be useful context: my Spark ablation tested three JEPA loss weights (10%, 30%, 50%) and they produced identical BPB to 4 decimal places. The JEPA auxiliary appears to be neutral at this training duration. The predictor does learn (its loss curve decreases) but that knowledge does not transfer to the AR objective in 200 steps. This might change with a JEPA-first pre-training curriculum followed by pure AR finetuning, which is a direction I haven't explored yet.
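For anyone replicating the artifact-size check, the stripping step can be as simple as filtering predictor parameters out of the state dict before export. A minimal sketch, assuming the auxiliary head lives under a `jepa_predictor.` prefix (the actual module name in the PR may differ):

```python
import torch

def export_without_predictor(model, path):
    # Keep only the shared AR backbone; drop the JEPA predictor head so the
    # serialized artifact carries no extra bytes. The "jepa_predictor." prefix
    # is an assumption, not necessarily the name used in train_gpt_jepa.py.
    state = {k: v for k, v in model.state_dict().items()
             if not k.startswith("jepa_predictor.")}
    torch.save(state, path)
```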
Non-record: LLM-JEPA - Joint Embedding Prediction for Language Modeling
val_bpb: 2.2020 | 1x RTX 5090 Ada 16GB, 180s wallclock | sp1024
Implements OpenAI's requested "JEPA" research direction.
Architecture
Results
Key Findings
JEPA as auxiliary loss does not help AR performance in this regime. The 2.2020 BPB is nearly identical to random adapters (2.2017), suggesting the JEPA objective is not contributing meaningful signal to the AR task in 180s of training.
The halved batch size hurts significantly. To fit JEPA's auxiliary computation, batch size was halved. This alone likely accounts for much of the BPB regression from baseline. A fairer comparison would use gradient accumulation.
JEPA predictor converges to non-trivial representations. The JEPA loss decreases during training, indicating the predictor learns something. But this representation knowledge does not transfer to the AR objective within the training time.
The prediction head adds minimal parameters. The lightweight MLP predictor is small and stripped from the export. It does not affect artifact size.
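For readers who want the shape of the idea without reading the diff, here is a minimal, generic sketch of a joint-embedding auxiliary term of this kind. It assumes mean-pooled final-layer hidden states of two halves of a sequence serve as the context and target embeddings; the actual view construction, pooling, and distance used in train_gpt_jepa.py may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class JEPAPredictor(nn.Module):
    """Lightweight MLP that maps a context embedding to a predicted target embedding."""
    def __init__(self, d_model: int, hidden: int = 512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, hidden), nn.GELU(), nn.Linear(hidden, d_model)
        )

    def forward(self, x):
        return self.net(x)

def jepa_aux_loss(hidden, predictor, split=None):
    """Predict the embedding of the second half of the sequence from the first half.

    hidden: (B, T, d) final-layer hidden states from the shared backbone.
    """
    T = hidden.size(1)
    split = split or T // 2
    ctx = hidden[:, :split].mean(dim=1)           # context view embedding
    tgt = hidden[:, split:].mean(dim=1).detach()  # target view, stop-gradient
    pred = predictor(ctx)
    return 1.0 - F.cosine_similarity(pred, tgt, dim=-1).mean()

# Combined objective (the 0.3 weight is illustrative):
# loss = ar_loss + 0.3 * jepa_aux_loss(hidden_states, jepa_predictor)
```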
Comparison to Naive Baseline
Reproduction
Discussion
JEPA's strength is learning abstract representations that go beyond next-token prediction. In the parameter-golf setting, however, the eval metric IS next-token prediction (BPB), so the JEPA objective's benefits are indirect at best. The approach might show stronger results with: (a) longer training where JEPA representations have time to influence the shared backbone, (b) larger batch sizes eliminating the memory constraint, or (c) a curriculum that pre-trains with JEPA then switches to pure AR.
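On point (b), the memory constraint does not have to change the optimizer's effective batch size: gradient accumulation over micro-batches would restore the baseline batch at the cost of extra wallclock. A minimal sketch under that assumption (all names are illustrative):

```python
def train_with_accumulation(model, optimizer, train_loader, compute_total_loss, accum_steps=2):
    """Gradient accumulation sketch: accum_steps=2 restores the baseline effective
    batch size after the JEPA-related halving, trading memory for wallclock."""
    optimizer.zero_grad(set_to_none=True)
    for i, batch in enumerate(train_loader):
        loss = compute_total_loss(model, batch)   # AR loss + weighted JEPA term
        (loss / accum_steps).backward()           # scale so accumulated grads average
        if (i + 1) % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
```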
Would welcome suggestions on JEPA variants that might transfer better to AR evaluation.
Credits
Script: train_gpt_jepa.py
Implements OpenAI's requested "JEPA" direction from the README.