# JEPA-LM: When Synthetic Success Doesn't Transfer to Real Language

**Non-Record Submission (Research Contribution / Negative Result)**
**Author:** Himanshu Dongre ([@himanshudongre](https://github.com/himanshudongre))
**Compute:** $0 (all experiments on Mac Mini M4, MPS backend)
**Status:** Negative result -- JEPA provides no meaningful benefit for real language modeling at this scale

---

## The Short Version

I implemented JEPA (Joint Embedding Predictive Architecture) as a training-time auxiliary loss for language modeling. On synthetic Markov chain data, JEPA showed a **dramatic -19.5% cross-entropy improvement** over a standard Transformer baseline. On real English text (Project Gutenberg), the improvement collapsed to **-0.24%** with **+40% throughput overhead** -- a decisively net-negative result.

This is a cautionary tale about validation methodology: synthetic benchmarks can be wildly misleading. The repetitive statistical patterns in Markov chains are exactly what JEPA's representation prediction excels at, but natural language doesn't exhibit such patterns at a scale where JEPA's overhead pays off.

I'm submitting this because JEPA is on OpenAI's "Requests for PRs" wishlist, and negative results with clear explanations are often more valuable than marginal positive ones. If you're considering JEPA for Parameter Golf, this document explains why it doesn't work and should save you from making the same mistake.

---

## Table of Contents

1. [Motivation](#motivation)
2. [How JEPA-LM Works](#how-jepa-lm-works)
3. [Synthetic Data Results (Promising)](#synthetic-data-results-promising)
4. [Real Text Results (Disappointing)](#real-text-results-disappointing)
5. [Why the Gap?](#why-the-gap)
6. [Could JEPA Work With Changes?](#could-jepa-work-with-changes)
7. [Connection to SSM Work](#connection-to-ssm-work)
8. [Reproducing These Results](#reproducing-these-results)

---

## Motivation

After my two-pass n-gram rescoring PR (#846) was closed in the enforcement sweep (Issue #677), I committed to pursuing pure architectural innovation -- no eval-time tricks. I wanted training-time techniques that produce better model weights without modifying the evaluation procedure.

JEPA was appealing for three reasons:

1. **Zero eval-time overhead.** The JEPA target encoder is ephemeral -- used only during training, never stored in the 16MB artifact. At eval time, the model is a standard Transformer. This sidesteps the throughput trap that kills novel architectures (see PR #831's analysis).

2. **Richer gradient signal.** Instead of just predicting next-token distributions (a sparse signal from a 1024-way vocabulary), JEPA predicts dense representations in a learned latent space. In theory, this provides a more informative training signal.

3. **OpenAI asked for it.** JEPA is explicitly listed in the "Requests for PRs" section of the README.

## How JEPA-LM Works

### Standard LM Training
```
Input tokens -> Encoder -> LM Head -> Cross-entropy loss vs. true next token
```

### JEPA-LM Training
```
Input tokens -> Online Encoder -> LM Head -> CE loss (standard)
                      |
                      +-> Predictor -> predicted representation of next token
                                            |
                                            v
Input tokens -> Target Encoder (EMA) -> target representation of next token
                                            |
                                            v
                      JEPA loss = ||predicted - target||^2
                                            |
                                            v
                      Total loss = CE + lambda * JEPA
```

The target encoder is an exponential moving average (EMA) of the online encoder, updated every step. Because no gradient flows through it, it provides stable prediction targets and prevents representational collapse.
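
The training step above can be sketched in a few lines. This is a toy NumPy sketch, not the repo's actual implementation (which runs PyTorch on MPS); the one-matrix "encoders", linear predictor, decay value, and lambda are all illustrative stand-ins.

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 4

# Toy "encoders": one weight matrix each (stand-ins for full Transformers).
online = {"W": rng.standard_normal((dim, dim))}
target = {k: v.copy() for k, v in online.items()}  # target starts as a copy
predictor = rng.standard_normal((dim, dim))        # small linear predictor

def encode(params, x):
    return x @ params["W"]

# --- one training step ---
x, x_next = rng.standard_normal((2, 3, dim))   # current / next token states

h_online = encode(online, x)            # online representation of context
pred = h_online @ predictor             # predicted next-token representation
h_target = encode(target, x_next)       # target representation (no gradient)

jepa = float(np.mean((pred - h_target) ** 2))  # JEPA loss: ||predicted - target||^2
ce = 1.0                                       # stand-in for the cross-entropy term
lam = 0.5
total = ce + lam * jepa                        # Total loss = CE + lambda * JEPA

# EMA update: target <- decay * target + (1 - decay) * online
decay = 0.999
target = {k: decay * target[k] + (1 - decay) * online[k] for k in target}
```

Note that the EMA update touches every parameter and the target encoder needs its own full forward pass, which is where the training-time overhead discussed below comes from.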

### Key Design Choice: Training-Only

The target encoder, predictor, and JEPA loss are **completely discarded after training**. The exported model is a standard Transformer. The hypothesis is that JEPA's auxiliary loss produces better internal representations that persist in the trained weights.

### Parameter Overhead During Training

| Component | Params | Note |
|-----------|--------|------|
| Online encoder | Same as baseline | Standard model |
| Target encoder | Same as baseline | EMA copy (not stored) |
| Predictor | ~dim x dim | Small MLP |
| **Training overhead** | **~2x memory** | Target encoder is full copy |

At eval/export time: **zero overhead**. The target encoder doesn't exist.
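
The export step under this design reduces to keeping only the online encoder's weights. A minimal sketch -- the checkpoint key names here are hypothetical, not the repo's actual layout:

```python
# Hypothetical checkpoint after training (values stand in for weight tensors).
training_state = {
    "encoder.layers.0.attn.weight": 1.0,
    "encoder.lm_head.weight": 2.0,
    "target_encoder.layers.0.attn.weight": 3.0,  # EMA copy: discarded
    "predictor.fc.weight": 4.0,                  # JEPA predictor: discarded
}

# Export keeps only the online encoder; the artifact is a standard Transformer.
export_state = {k: v for k, v in training_state.items() if k.startswith("encoder.")}
```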

## Synthetic Data Results (Promising)

### Setup
- Synthetic Markov chain text (controlled statistical patterns)
- dim=192, 6 layers, 3000 training steps
- Mac Mini M4, MPS backend

### Results

| Model | Final CE | ms/step | CE vs Baseline |
|-------|----------|---------|----------------|
| Pure CE (Baseline) | 0.8031 | 210-590ms | -- |
| **JEPA-LM Hybrid** | **0.6466** | 750-812ms | **-19.5%** |
| JEPA-LM + MoD | 0.7644 | 454-888ms | -5% |

The JEPA-LM hybrid showed:
- **-19.5% cross-entropy improvement** over pure CE training
- Slower convergence for the first ~1400 steps, then dramatic improvement
- A crossover point suggesting JEPA needs time to learn useful representations before they benefit the LM head

I was genuinely excited. A 19.5% improvement, even with throughput overhead, seemed like it would easily translate to a BPB improvement at full scale.

It didn't.

## Real Text Results (Disappointing)

### Setup
- 4 Project Gutenberg books (1.9MB of real English text)
- Same architecture: dim=192, 6 layers, 2000 steps
- Same JEPA configuration that showed -19.5% on synthetic data

### Results

| Model | Eval CE | ms/step | CE vs Baseline | Net Impact |
|-------|---------|---------|----------------|------------|
| Pure CE (Baseline) | 1.2779 | 210.6ms | -- | -- |
| **JEPA-LM** | **1.2748** | **294.4ms** | **-0.24%** | **-28.2%** |

- Cross-entropy improvement: **virtually zero** (-0.24%)
- Throughput penalty: **+39.8%** (the extra target-encoder forward pass dominates, with the EMA update adding more on top)
- Net competition impact: **strongly negative** (-28.2% after accounting for fewer training steps)

### The Reality Check

At competition scale (600s on 8xH100), the +40% throughput overhead means ~1500 fewer training steps. A -0.24% quality improvement cannot compensate for losing 30% of your training budget.
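
The budget arithmetic behind that claim can be checked from the measured per-step times in the table above (M4 timings; absolute H100 numbers differ, but the relative overhead is what matters):

```python
# Measured ms/step from the real-text run.
baseline_ms, jepa_ms = 210.6, 294.4

overhead = jepa_ms / baseline_ms - 1   # ~ +39.8% per-step cost
steps_kept = 1 / (1 + overhead)        # fraction of baseline steps in a fixed time budget
steps_lost = 1 - steps_kept            # ~28.5% fewer steps, close to the
                                       # -28.2% net-impact figure above
print(f"overhead {overhead:+.1%}, steps lost {steps_lost:.1%}")
```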

## Why the Gap?

### Markov Chains Have Exploitable Structure

Synthetic Markov chain data has **simple, repetitive statistical patterns** -- transition probabilities between states are fixed and learnable. JEPA's representation prediction excels here because:

1. The target encoder learns stable representations of these patterns quickly
2. The predictor can accurately forecast what the next representation should be
3. The resulting gradient signal genuinely helps the online encoder learn faster

### Natural Language Doesn't

Real English text has:
- **Long-range dependencies** that change with context
- **Semantic ambiguity** where the same prefix leads to many valid continuations
- **Non-stationary statistics** across documents, genres, and topics

JEPA's representation prediction becomes nearly meaningless when the next token is genuinely unpredictable from the current representation. The predictor can't learn a useful mapping from "current representation" to "next representation" because the mapping is inherently many-to-many in natural language.

### The Overhead Wasn't Worth It

Even if JEPA provided a marginal quality improvement, the training-time overhead is fundamental:
- Target encoder EMA update: O(params) per step
- Forward pass through the target encoder: same cost as the main model
- Predictor forward + loss: small but nonzero

In Parameter Golf, every millisecond costs ~7 training steps. At +40% overhead, JEPA would need to improve per-step learning by 40% just to break even. That's a much higher bar than the -0.24% it achieved.

## Could JEPA Work With Changes?

### Maybe, But Unlikely at This Scale

Possible improvements:
1. **Cheaper target encoder:** Only EMA-update a subset of layers. Reduces overhead but also reduces the signal quality.
2. **Larger scale:** At dim=512+, the representation space is richer, and JEPA predictions might be more useful. But the overhead also grows.
3. **Different prediction targets:** Instead of the next-token representation, predict chunk-level or multi-token representations. More stable targets, potentially more useful for language.
4. **Domain-specific fine-tuning:** JEPA might work better on highly structured text (code, math) where next-token prediction is more deterministic.
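
Option 1 can be made concrete. A hypothetical partial-EMA update that maintains an EMA copy only for the last layers and ties the rest directly to the online weights (the layer names and decay are illustrative, not from the repo):

```python
def partial_ema_update(target, online, decay=0.999,
                       ema_prefixes=("layers.4.", "layers.5.")):
    """EMA-update only the named layers; tie the rest to the online encoder.

    Tied layers need no separate copy and no EMA arithmetic, cutting the
    target-encoder overhead roughly in proportion to the layers skipped.
    """
    updated = {}
    for name, tgt in target.items():
        if any(name.startswith(p) for p in ema_prefixes):
            updated[name] = decay * tgt + (1 - decay) * online[name]
        else:
            updated[name] = online[name]  # shared with the online weights
    return updated

# Toy example with scalar "weights":
online = {"layers.0.w": 1.0, "layers.5.w": 1.0}
target = {"layers.0.w": 0.0, "layers.5.w": 0.0}
target = partial_ema_update(target, online)
# layers.0.w is tied to the online value; layers.5.w moves slowly toward it
```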

### My Honest Assessment

JEPA is a beautiful idea for vision (where spatial structure makes representation prediction natural). For language modeling at the scale of Parameter Golf, the cost-benefit ratio is wrong. The auxiliary loss is expensive and the signal is too weak for natural language at dim=192-512.

If someone wants to push this further, I'd suggest trying dim=768+ with a much cheaper target encoder (EMA only the last 2 layers). But I wouldn't bet on it.

## Connection to SSM Work

After JEPA failed, I pivoted to S4D-Lin State Space Models (see the companion PR). The key lesson from JEPA informed my SSM approach: **always validate on real text first**. I ran the SSM through local real-text validation at seq_len=512 before spending any GPU credits.

Ironically, the SSM local tests also turned out to be misleading (for different reasons -- the quality advantage at dim=192 didn't hold at dim=512). The full story is in the SSM PR.

## Reproducing These Results

### Local Experiments

All experiments run on a Mac Mini M4 (MPS backend); no GPU required.

```bash
# Synthetic data test (Markov chains)
python3 -u jepa_mod_experiment.py

# Real text test (requires text_corpus.txt from Project Gutenberg)
python3 -u jepa_real_text_test.py
```

### Files

- `jepa_mod_experiment.py`: Full JEPA-LM implementation with CE, JEPA hybrid, and JEPA+MoD variants
- `jepa_real_text_test.py`: Real text validation script
- `jepa_mod_results.json`: Synthetic data benchmark results
- `jepa_real_text_results.json`: Real text benchmark results

---

*This submission checks off "JEPA" from the Requests for PRs wishlist. Sometimes the most valuable research contribution is showing definitively why a promising idea doesn't work in a specific setting.*