
Commit 3b9e547

Non-record: JEPA-LM — When Synthetic Success Doesn't Transfer to Real Language
Implements JEPA (Joint Embedding Predictive Architecture) as training-time auxiliary loss for language modeling. On synthetic Markov chain data, JEPA showed -19.5% cross-entropy improvement. On real English text, the improvement collapsed to -0.24% with +40% throughput overhead. Key finding: Markov chains have exploitable repetitive statistical structure that JEPA excels at, but natural language doesn't. This is a cautionary tale about synthetic benchmark validation. Checks off "JEPA" from Requests for PRs. See companion PR (S4D-Lin SSM Hybrid) for where this research led next.
1 parent 50390d6 commit 3b9e547

4 files changed

Lines changed: 1512 additions & 0 deletions

# JEPA-LM: When Synthetic Success Doesn't Transfer to Real Language

**Non-Record Submission (Research Contribution / Negative Result)**
**Author:** Himanshu Dongre ([@himanshudongre](https://github.com/himanshudongre))
**Compute:** $0 (all experiments on Mac Mini M4, MPS backend)
**Status:** Negative result -- JEPA provides no meaningful benefit for real language modeling at this scale

---

## The Short Version

I implemented JEPA (Joint Embedding Predictive Architecture) as a training-time auxiliary loss for language modeling. On synthetic Markov chain data, JEPA showed a **dramatic -19.5% cross-entropy improvement** over a standard Transformer baseline. On real English text (Project Gutenberg), the improvement collapsed to **-0.24%** with **+40% throughput overhead** -- a massively net-negative result.

This is a cautionary tale about validation methodology: synthetic benchmarks can be wildly misleading. The repetitive statistical patterns in Markov chains are exactly what JEPA's representation prediction excels at, but natural language doesn't have that kind of structure, and the small gain JEPA does provide comes nowhere near justifying its overhead at this scale.

I'm submitting this because JEPA is on OpenAI's "Requests for PRs" wishlist, and negative results with clear explanations are often more valuable than marginal positive ones. If you're considering JEPA for Parameter Golf, this document will explain why it doesn't work and save you from making the same mistake.

---

## Table of Contents

1. [Motivation](#motivation)
2. [How JEPA-LM Works](#how-jepa-lm-works)
3. [Synthetic Data Results (Promising)](#synthetic-data-results-promising)
4. [Real Text Results (Disappointing)](#real-text-results-disappointing)
5. [Why the Gap?](#why-the-gap)
6. [Could JEPA Work With Changes?](#could-jepa-work-with-changes)
7. [Connection to SSM Work](#connection-to-ssm-work)
8. [Reproducing These Results](#reproducing-these-results)

---

## Motivation

After my two-pass n-gram rescoring PR (#846) was closed in the enforcement sweep (Issue #677), I committed to pursuing pure architectural innovation -- no eval-time tricks. I wanted training-time techniques that produce better model weights without modifying the evaluation procedure.

JEPA was appealing for three reasons:

1. **Zero eval-time overhead.** The JEPA target encoder is ephemeral -- used only during training, never stored in the 16MB artifact. At eval time, the model is a standard Transformer. This sidesteps the throughput trap that kills novel architectures (see PR #831's analysis).

2. **Richer gradient signal.** Instead of just predicting next-token distributions (a sparse signal from a 1024-way vocabulary), JEPA predicts dense representations in a learned latent space. In theory, this provides a more informative training signal.

3. **OpenAI asked for it.** JEPA is explicitly listed in the "Requests for PRs" section of the README.

## How JEPA-LM Works

### Standard LM Training
```
Input tokens -> Encoder -> LM Head -> Cross-entropy loss vs. true next token
```

### JEPA-LM Training
```
Input tokens -> Online Encoder -> LM Head -> CE loss (standard)
                      |
                      +-> Predictor -> predicted representation of next token
                                                     |
Input tokens -> Target Encoder (EMA) -> target representation of next token
                                                     |
                                                     v
                               JEPA loss = ||predicted - target||^2
                                                     |
                                                     v
                               Total loss = CE + lambda * JEPA
```

The target encoder is an exponential moving average (EMA) of the online encoder, updated every step. It provides stable prediction targets without collapse (no gradient flows through the target encoder).
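
To make the hybrid objective concrete, here is a minimal PyTorch sketch of a training step as described above. It is illustrative only: `TinyEncoder` is a stand-in for the actual Transformer, and names like `JEPALM`, `ema_decay`, and `jepa_lambda` are mine, not the identifiers used in `jepa_mod_experiment.py`.

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyEncoder(nn.Module):
    """Stand-in for the Transformer encoder; returns per-token hidden states."""
    def __init__(self, vocab=1024, dim=192):
        super().__init__()
        self.emb = nn.Embedding(vocab, dim)
        self.mix = nn.GRU(dim, dim, batch_first=True)  # placeholder causal mixer

    def forward(self, tokens):                         # tokens: (B, T)
        h, _ = self.mix(self.emb(tokens))
        return h                                       # (B, T, dim)

class JEPALM(nn.Module):
    def __init__(self, vocab=1024, dim=192, ema_decay=0.996, jepa_lambda=0.1):
        super().__init__()
        self.online = TinyEncoder(vocab, dim)
        self.target = copy.deepcopy(self.online)       # EMA copy, training-only
        for p in self.target.parameters():
            p.requires_grad_(False)
        self.predictor = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))
        self.lm_head = nn.Linear(dim, vocab)
        self.ema_decay, self.jepa_lambda = ema_decay, jepa_lambda

    @torch.no_grad()
    def ema_update(self):
        # target <- decay * target + (1 - decay) * online, called every step
        for pt, po in zip(self.target.parameters(), self.online.parameters()):
            pt.lerp_(po, 1.0 - self.ema_decay)

    def loss(self, tokens):
        inp, nxt = tokens[:, :-1], tokens[:, 1:]
        h = self.online(inp)                                        # online reps
        ce = F.cross_entropy(self.lm_head(h).transpose(1, 2), nxt)  # standard LM loss
        pred = self.predictor(h)                       # predict next-token representation
        with torch.no_grad():                          # no gradient into the target
            tgt = self.target(tokens)[:, 1:]           # target reps of the next tokens
        jepa = F.mse_loss(pred, tgt)                   # ||predicted - target||^2
        return ce + self.jepa_lambda * jepa
```

In a training loop, each step would compute `loss(tokens)`, backpropagate, take an optimizer step, and then call `ema_update()`.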

### Key Design Choice: Training-Only

The target encoder, predictor, and JEPA loss are **completely discarded after training**. The exported model is a standard Transformer. The hypothesis is that JEPA's auxiliary loss produces better internal representations that persist in the trained weights.

### Parameter Overhead During Training

| Component | Params | Note |
|-----------|--------|------|
| Online encoder | Same as baseline | Standard model |
| Target encoder | Same as baseline | EMA copy (not stored) |
| Predictor | ~dim x dim | Small MLP |
| **Training overhead** | **~2x memory** | Target encoder is full copy |

At eval/export time: **zero overhead**. The target encoder doesn't exist.
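
To illustrate (continuing the hypothetical `JEPALM` sketch above; the real export path in this repo may differ), only the online encoder and LM head ever reach the artifact:

```python
import torch

# Hypothetical export step: the EMA target encoder and the predictor are
# simply never written out, so the artifact holds a plain Transformer.
model = JEPALM()
eval_weights = {k: v for k, v in model.state_dict().items()
                if k.startswith(("online.", "lm_head."))}
torch.save(eval_weights, "artifact.pt")
```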

## Synthetic Data Results (Promising)

### Setup
- Synthetic Markov chain text (controlled statistical patterns)
- dim=192, 6 layers, 3000 training steps
- Mac Mini M4, MPS backend

### Results

| Model | Final CE | ms/step | CE vs Baseline |
|-------|----------|---------|----------------|
| Pure CE (Baseline) | 0.8031 | 210-590ms | -- |
| **JEPA-LM Hybrid** | **0.6466** | 750-812ms | **-19.5%** |
| JEPA-LM + MoD | 0.7644 | 454-888ms | -5% |

The JEPA-LM hybrid showed:
- **-19.5% cross-entropy improvement** over pure CE training
- Slower convergence for the first ~1400 steps, then dramatic improvement
- The crossover point suggested JEPA needs time to learn useful representations before they benefit the LM head

I was genuinely excited. A 19.5% improvement, even with throughput overhead, seemed like it would easily translate to a BPB improvement at full scale.

It didn't.

## Real Text Results (Disappointing)

### Setup
- 4 Project Gutenberg books (1.9MB of real English text)
- Same architecture: dim=192, 6 layers, 2000 steps
- Same JEPA configuration that showed -19.5% on synthetic data

### Results

| Model | Eval CE | ms/step | CE vs Baseline | Net Impact |
|-------|---------|---------|----------------|------------|
| Pure CE (Baseline) | 1.2779 | 210.6ms | -- | -- |
| **JEPA-LM** | **1.2748** | **294.4ms** | **-0.24%** | **-28.2%** |

- Cross-entropy improvement: **virtually zero** (-0.24%)
- Throughput penalty: **+39.8%** (the target encoder EMA is expensive)
- Net competition impact: **strongly negative** (-28.2% after accounting for fewer training steps)

### The Reality Check

At competition scale (600s on 8xH100), the +40% throughput overhead means ~1500 fewer training steps. A -0.24% quality improvement cannot compensate for losing 30% of your training budget.
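
For transparency, here is my back-of-envelope reading of how the measured numbers above combine. The percentage arithmetic comes straight from the table; how the -28.2% figure decomposes is my reconstruction, and the absolute competition step counts depend on H100 step times these local measurements don't capture.

```python
# Back-of-envelope from the local real-text measurements above.
baseline_ms, jepa_ms = 210.6, 294.4      # measured ms/step
overhead = jepa_ms / baseline_ms - 1     # ~ +39.8% slower per step
steps_lost = 1 - baseline_ms / jepa_ms   # ~ 28.5% fewer steps in a fixed time budget
quality_gain = 0.0024                    # the -0.24% eval cross-entropy improvement
net_impact = quality_gain - steps_lost   # ~ -28.2%, the "Net Impact" column above
print(f"overhead {overhead:+.1%}, steps lost {steps_lost:.1%}, net {net_impact:+.1%}")
```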

## Why the Gap?

### Markov Chains Have Exploitable Structure

Synthetic Markov chain data has **simple, repetitive statistical patterns** -- transition probabilities between states are fixed and learnable. JEPA's representation prediction excels here because:

1. The target encoder learns stable representations of these patterns quickly
2. The predictor can accurately forecast what the next representation should be
3. The resulting gradient signal genuinely helps the online encoder learn faster
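
For intuition about what "fixed and learnable" means here, a toy generator along the following lines produces this kind of data. This is a guess at the general shape, not the actual generator in `jepa_mod_experiment.py`.

```python
import numpy as np

# Hypothetical Markov-chain token generator: peaked Dirichlet rows give each
# state a few strongly preferred successors, so the next token (and hence its
# representation) is close to a deterministic function of the current state.
def markov_tokens(n_tokens=100_000, n_states=64, concentration=0.1, seed=0):
    rng = np.random.default_rng(seed)
    trans = rng.dirichlet([concentration] * n_states, size=n_states)  # (S, S), rows sum to 1
    seq = np.empty(n_tokens, dtype=np.int64)
    seq[0] = rng.integers(n_states)
    for i in range(1, n_tokens):
        seq[i] = rng.choice(n_states, p=trans[seq[i - 1]])
    return seq
```

In that regime the predictor's job is easy, and the dense JEPA signal genuinely accelerates learning. None of this carries over to English prose.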

### Natural Language Doesn't

Real English text has:
- **Long-range dependencies** that change with context
- **Semantic ambiguity** where the same prefix leads to many valid continuations
- **Non-stationary statistics** across documents, genres, and topics

JEPA's representation prediction becomes nearly meaningless when the next token is genuinely unpredictable from the current representation. The predictor can't learn a useful mapping from "current representation" to "next representation" because in natural language that mapping is inherently one-to-many.

### The Overhead Wasn't Worth It

Even if JEPA provided a marginal quality improvement, the training-time overhead is fundamental:
- Target encoder EMA update: O(params) per step
- Forward pass through target encoder: same cost as the main model
- Predictor forward + loss: small but nonzero

In Parameter Golf, every millisecond costs ~7 training steps. At +40% overhead, JEPA would need to improve per-step learning by 40% just to break even. That's a much higher bar than the -0.24% it achieved.
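
For reference, the "~7 training steps per millisecond" figure falls out of a fixed wall-clock budget; the ~290 ms/step latency below is an assumption (roughly the local JEPA step time above), not a measured competition number.

```python
# Sensitivity of step count to per-step latency under a fixed 600 s budget.
budget_ms = 600 * 1000                        # wall-clock training budget
step_ms = 290.0                               # assumed per-step latency
steps = budget_ms / step_ms                   # ~2069 steps
steps_per_extra_ms = budget_ms / step_ms**2   # ~7.1 steps lost per +1 ms/step
# A +40% per-step overhead leaves 1/1.4 ~ 71% of the steps, so each remaining
# step would have to teach ~40% more just to match the baseline's total learning.
print(f"{steps:.0f} steps; ~{steps_per_extra_ms:.1f} steps lost per extra ms/step")
```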

## Could JEPA Work With Changes?

### Maybe, But Unlikely at This Scale

Possible improvements:

1. **Cheaper target encoder:** Only EMA-update a subset of layers. Reduces overhead but also reduces the signal quality.
2. **Larger scale:** At dim=512+, the representation space is richer, and JEPA predictions might be more useful. But the overhead also grows.
3. **Different prediction targets:** Instead of next-token representation, predict chunk-level or multi-token representations. More stable targets, potentially more useful for language.
4. **Domain-specific fine-tuning:** JEPA might work better on highly structured text (code, math) where next-token prediction is more deterministic.

### My Honest Assessment

JEPA is a beautiful idea for vision (where spatial structure makes representation prediction natural). For language modeling at the scale of Parameter Golf, the cost-benefit ratio is wrong. The auxiliary loss is expensive and the signal is too weak for natural language at dim=192-512.

If someone wants to push this further, I'd suggest trying at dim=768+ with a much cheaper target encoder (EMA only the last 2 layers). But I wouldn't bet on it.
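
For concreteness, the "EMA only the last 2 layers" variant could look something like this, assuming the encoders expose their blocks as a `layers` ModuleList; I have not benchmarked it.

```python
import torch

@torch.no_grad()
def partial_ema_update(online, target, decay=0.996, n_trailing=2):
    """Hypothetical cheaper EMA: only the last `n_trailing` blocks track the
    online encoder, so the per-step update cost shrinks accordingly (the
    target forward pass, however, still costs a full model's worth)."""
    for o_blk, t_blk in zip(online.layers[-n_trailing:], target.layers[-n_trailing:]):
        for po, pt in zip(o_blk.parameters(), t_blk.parameters()):
            pt.lerp_(po, 1.0 - decay)
```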

## Connection to SSM Work

After JEPA failed, I pivoted to S4D-Lin State Space Models (see companion PR). The key lesson from JEPA informed my SSM approach: **always validate on real text first**. I ran the SSM through local real-text validation at seq_len=512 before spending any GPU credits.

Ironically, the SSM local tests also turned out to be misleading (for different reasons -- the quality advantage at dim=192 didn't hold at dim=512). The full story is in the SSM PR.

## Reproducing These Results

### Local Experiments

All experiments were run on a Mac Mini M4 (MPS backend); no GPU is required.

```bash
# Synthetic data test (Markov chains)
python3 -u jepa_mod_experiment.py

# Real text test (requires text_corpus.txt from Project Gutenberg)
python3 -u jepa_real_text_test.py
```

### Files

- `jepa_mod_experiment.py`: Full JEPA-LM implementation with CE, JEPA hybrid, and JEPA+MoD variants
- `jepa_real_text_test.py`: Real text validation script
- `jepa_mod_results.json`: Synthetic data benchmark results
- `jepa_real_text_results.json`: Real text benchmark results

---

*This submission checks off "JEPA" from the Requests for PRs wishlist. Sometimes the most valuable research contribution is showing definitively why a promising idea doesn't work in a specific setting.*
