records/track_non_record_16mb/2026-04-01_TTT_E2E_FOMAML/README.md
@@ -0,0 +1,141 @@
# Prime MLP Test-Time Training: Naive vs E2E (FOMAML)

Two studies on test-time training with prime MLP adapters. **Naive TTT gives
-0.079 BPB for free (eval-only). E2E FOMAML gives -0.097 total but costs
44% of the training budget.**

## Motivation

All 25 prior naive TTT attempts failed because they perturbed GPTQ'd int5/int6
weights. Prime MLPs are separate bf16 parameters — they don't touch GPTQ'd weights.

## Architecture

Rank-256 prime MLPs on all 11 blocks, running before the main MLP:

```
h = h + attn(norm(h))
h = h + prime_MLP(prime_norm(h)) # bf16, adapted via SGD at eval time
h = h + MLP(mlp_norm(h)) # GPTQ'd int5/int6, frozen
```

The down projection is zero-initialized, so the model starts unchanged. Score-first eval is legal: each window is scored before any gradient step on it, so the TTT update never leaks into its own score.
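
A minimal sketch of one adapter, with hypothetical names (`PrimeMLP`, `d_model`); the width, norm type, and activation (GELU here) are assumptions, not the repo's exact module:

```python
import torch
import torch.nn as nn

class PrimeMLP(nn.Module):
    """Low-rank bf16 adapter: norm -> up-projection -> GELU -> zero-init down-projection."""
    def __init__(self, d_model: int, rank: int = 256):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)             # stands in for prime_norm
        self.up = nn.Linear(d_model, rank, bias=False)
        self.down = nn.Linear(rank, d_model, bias=False)
        nn.init.zeros_(self.down.weight)              # zero-init: the residual branch starts as a no-op

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        return self.down(nn.functional.gelu(self.up(self.norm(h))))

# In each block's forward, only the prime branch carries trainable bf16 weights:
#   h = h + attn(attn_norm(h))        # frozen
#   h = h + prime_mlp(h)              # bf16, adapted with plain SGD at eval time
#   h = h + mlp(mlp_norm(h))          # GPTQ'd int5/int6, frozen
```

At eval time only these parameters get `requires_grad=True`, so a standard `torch.optim.SGD` over them is the entire TTT machinery.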

## Results

### Full-data training (40 shards, MLP 3.5x, 10K steps, 1x L40S)

| Config | val_bpb | Delta |
|--------|---------|-------|
| **Baseline (EMA, no TTT)** | **1.3696** | — |
| **TTT lr=0.1, all 11 layers** | **1.2906** | **-0.079** |

### Sweep summary (5K chunks, full-data model)

| Experiment | val_bpb | Delta |
|------------|---------|-------|
| Baseline | 1.3696 | — |
| lr=0.03 | 1.3670 | -0.003 |
| lr=0.1 | 1.3636 | -0.006 |
| lr=0.3 | 1.3601 | -0.010 |
| lr=1.0 | 1.3550 | -0.015 |
| rank=64 (3 layers) | 1.3661 | -0.004 |
| rank=512 (3 layers) | 1.3638 | -0.006 |
| layer=[10] only | 1.3669 | -0.003 |
| layer=[6..10] | 1.3609 | -0.009 |
| **layer=all (11)** | **1.3242** | **-0.045** |
| momentum=0.9 | 1.3574 | -0.012 |

### Key findings (Study 1)

1. **Layer count >> rank.** All 11 layers (-0.045) crushes rank=512 on 3 layers (-0.006)
2. **Higher LR is better** up to 1.0 (still improving, ceiling not found)
3. **Full eval compounds** — 60K chunks gives ~1.8x the 5K-chunk delta
4. **Effect scales with model quality** — -0.079 on strong model vs -0.073 on weak
5. **Momentum=0.9 helps** (roughly doubles the delta at the same LR on 3 layers)

### 1-shard control (earlier experiment)

| Config | Baseline | TTT | Delta |
|--------|----------|-----|-------|
| 1 shard, 7200 steps | 1.5019 | 1.4288 | -0.073 |
| **40 shards, 10K steps** | **1.3696** | **1.2906** | **-0.079** |

## Artifact size for PR 1105

| Config | Prime MLP size | Fits 16 MB? |
|--------|---------------|-------------|
| rank=256, all 11 layers | 5.75 MB | No |
| rank=64, all 11 layers | 1.41 MB | Yes |
| rank=32, all 11 layers | 0.70 MB | Yes |

rank=64 with all 11 layers fits the 16 MB budget, and Study 1 shows rank matters far less than layer count.
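
As a back-of-envelope check (a sketch, not the repo's accounting: it assumes two bf16 projection matrices per adapter and a model width of roughly 512, back-solved from the 5.75 MB figure rather than stated in this README):

```python
def prime_mlp_mb(d_model: int, rank: int, n_layers: int, bytes_per_param: int = 2) -> float:
    """Size of the prime-MLP adapters: up (d_model x rank) + down (rank x d_model) per layer, bf16."""
    return n_layers * 2 * d_model * rank * bytes_per_param / 1e6

for rank in (256, 64, 32):
    print(f"rank={rank}: {prime_mlp_mb(d_model=512, rank=rank, n_layers=11):.2f} MB")
# rank=256: 5.77 MB, rank=64: 1.44 MB, rank=32: 0.72 MB -- within a few percent of the
# table; norm parameters and the exact model width account for the small differences.
```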

---

## Study 2: E2E TTT (FOMAML Meta-Learning)

### Method

Phase 2 FOMAML joint training on the strong 40-shard checkpoint. Base model at
0.0003 LR, prime MLPs at 0.003 LR. Inner loop: K=1 SGD step on prime weights.
Outer loop: both base and prime get gradients. 3000 steps.
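
A minimal sketch of that inner/outer structure, assuming `outer_opt` holds two param groups (base weights at lr=0.0003, prime weights at lr=0.003) and that `model(x, y)` returns the LM loss as in the eval script below; it illustrates the first-order update, not the repo's exact training loop:

```python
import torch

def fomaml_step(model, support_batch, query_batch, outer_opt, inner_lr=0.003, k=1):
    """One FOMAML outer step: adapt the prime MLPs on a support batch, then
    backprop the query loss into base + prime params, treating the inner
    SGD step as a constant (first-order)."""
    prime_params = [p for _, p in model.prime_named_params()]
    init = [p.detach().clone() for p in prime_params]      # remember the prime init

    # Inner loop: K plain SGD steps on the prime weights only.
    x_s, y_s = support_batch
    for _ in range(k):
        inner_loss = model(x_s, y_s)
        grads = torch.autograd.grad(inner_loss, prime_params)
        with torch.no_grad():
            for p, g in zip(prime_params, grads):
                p -= inner_lr * g

    # Outer loop: query loss through the adapted primes; gradients reach the
    # base weights directly and the primes at their adapted values.
    outer_opt.zero_grad(set_to_none=True)
    x_q, y_q = query_batch
    query_loss = model(x_q, y_q)
    query_loss.backward()

    # Restore the primes to their init, then let the outer optimizer apply the
    # (first-order) meta-gradient to that init and to the base weights.
    with torch.no_grad():
        for p, p0 in zip(prime_params, init):
            p.copy_(p0)
    outer_opt.step()
    return query_loss.detach()
```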

### Results

| Stage | val_bpb | Delta vs orig baseline |
|-------|---------|------------------------|
| Original baseline (no prime MLPs) | 1.3696 | — |
| FOMAML baseline (prime at zero) | 1.5185 | — |
| **Post-FOMAML (no TTT)** | **1.2588** | **-0.111** |
| E2E TTT (meta-learned init, lr=0.1) | 1.2656 | -0.104 |
| Naive TTT (zero-init on FOMAML model) | 1.2776 | -0.092 |

Joint FOMAML massively improves the model even without TTT (-0.260 from FOMAML
baseline). But TTT on top of FOMAML **slightly hurts** — the meta-learned init
is already tuned and SGD overshoots.

### TTT LR sweep on FOMAML model (5K chunks)

| Config | val_bpb | Delta vs FOMAML no-TTT |
|--------|---------|------------------------|
| FOMAML, no TTT | 1.2732 | — |
| + TTT lr=0.001 | 1.2731 | -0.000 |
| + TTT lr=0.01 | 1.2726 | -0.001 |
| + TTT lr=0.1 | 1.2720 | -0.001 |

TTT adds only -0.001 on top of FOMAML. The meta-learning already captured the
adaptation value during training.

### Key findings (Study 2)

1. **Joint FOMAML improves the model itself** — -0.260 BPB vs the FOMAML baseline, even with no TTT at eval
2. **TTT is nearly redundant after FOMAML** — only -0.001 additional benefit
3. **The base model co-adapts** — this isn't just adapter training; the whole model improves

---

## Head-to-head: Naive TTT vs E2E FOMAML

| Approach | Baseline | Best | Total Δ | Training cost |
|----------|----------|------|---------|---------------|
| **Naive TTT (eval-only)** | 1.3696 | 1.2906 | **-0.079** | **0** |
| FOMAML + TTT | 1.3696 | 1.2720 | -0.097 | 3000 steps (~44% budget) |

**Naive TTT is the practical winner** — zero training cost, 81% of FOMAML's benefit.
FOMAML is worth it only if the 44% training budget can be absorbed.

---

## Next steps

- [ ] 8xH100 validation on actual PR 1105 model (1.1125 BPB)
- [ ] Combine lr=1.0 + all layers + momentum=0.9 (untested combination)
- [ ] rank=64 all-layers full eval (fits 16 MB budget)

## Files

- `train_ttt_e2e.py` — Model with prime MLPs + FOMAML + TTT eval
- `train_e2e_proper.py` — Proper E2E training (Phase 1 + Phase 2 joint)
- `sweep_naive_ttt.py` — Naive TTT LR/chunk/reset sweep
- `sweep_v2.py` — LR/rank/layer/momentum sweep
@@ -0,0 +1,150 @@
"""Sliding window eval with TTT on sp4608 8xH100 model.

Matches PR 1105 eval: stride=64, seq_len=2048, BPB with sentencepiece byte counting.
TTT: score window, record BPB for scored tokens, then SGD update prime MLPs.
"""
from eval_sp4608_ttt import *

STRIDE = int(os.environ.get("STRIDE", 64))


def eval_sliding_ttt(model, val_tokens, device, ttt_lr, stride,
base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
momentum=0.0, max_windows=0):
"""Sliding window eval with score-first TTT."""
seq_len = SEQ_LEN
total_tokens = val_tokens.numel() - 1

# Generate window starts (same as PR 1105)
window_starts = [ws for ws in range(0, total_tokens, stride)
if ws + seq_len - stride < total_tokens]
if max_windows > 0:
window_starts = window_starts[:max_windows]
total_windows = len(window_starts)

prime_params = [p for _, p in model.prime_named_params()]
optimizer = torch.optim.SGD(prime_params, lr=ttt_lr, momentum=momentum) if ttt_lr > 0 else None

loss_sum = torch.zeros((), device=device, dtype=torch.float64)
token_count = torch.zeros((), device=device, dtype=torch.float64)
byte_count = torch.zeros((), device=device, dtype=torch.float64)
t0 = time.perf_counter()

model.eval()
batch_seqs = 16 # process multiple windows at once for speed

for bi in range(0, total_windows, batch_seqs):
batch_ws = window_starts[bi:bi + batch_seqs]
bsz = len(batch_ws)

x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
wlens = []
for i, ws in enumerate(batch_ws):
end = min(ws + seq_len, total_tokens)
wlen = end - ws
wlens.append(wlen)
chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device)
x_batch[i, :wlen] = chunk[:-1]
y_batch[i, :wlen] = chunk[1:]

# ── SCORE (no grad) ──
with torch.no_grad():
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
logits = model.forward_logits(x_batch)
nll = F.cross_entropy(
logits.reshape(-1, logits.size(-1)).float(),
y_batch.reshape(-1),
reduction="none",
).reshape(bsz, seq_len)

for i, ws in enumerate(batch_ws):
wlen = wlens[i]
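            # The first window scores every position; later windows score only
            # their last `stride` positions (the tokens not covered by any
            # previous window), matching the PR 1105 protocol.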
s = 0 if ws == 0 else max(wlen - stride, 0)
scored_nll = nll[i, s:wlen].to(torch.float64)
loss_sum += scored_nll.sum()
token_count += float(wlen - s)
tgt = y_batch[i, s:wlen]
prev = x_batch[i, s:wlen]
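            # Byte accounting for BPB: decoded byte count of each target token,
            # plus one byte for a leading-space piece unless the previous token
            # is a boundary token.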
tb = base_bytes_lut[tgt].to(torch.float64)
tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64)
byte_count += tb.sum()

# ── TTT UPDATE on scored windows ──
if optimizer is not None:
optimizer.zero_grad(set_to_none=True)
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
train_loss = model(x_batch, y_batch)
train_loss.backward()
optimizer.step()

if bi % (5000 * batch_seqs) == 0 or bi + batch_seqs >= total_windows:
elapsed = time.perf_counter() - t0
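            # bits-per-byte = total NLL in nats / ln(2) / total decoded bytes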
bpb = (loss_sum.item() / max(token_count.item(), 1)) / math.log(2.0) * \
(token_count.item() / max(byte_count.item(), 1))
windows_done = min(bi + batch_seqs, total_windows)
print(f" [{windows_done}/{total_windows}] bpb={bpb:.6f} t={elapsed:.0f}s")

val_loss = (loss_sum / token_count).item()
val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item())
print(f"done: val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} "
f"tokens={token_count.item():.0f} elapsed={time.perf_counter() - t0:.0f}s")
return val_bpb


def main():
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
device = torch.device("cuda")
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

sp = spm.SentencePieceProcessor(model_file=TOKENIZER_PATH)
val_tokens = load_validation_tokens(os.path.join(DATA_PATH, "fineweb_val_*.bin"), SEQ_LEN)
luts = build_sentencepiece_luts(sp, VOCAB_SIZE, device)
print(f"val tokens: {val_tokens.numel() - 1}, stride: {STRIDE}")

model = GPT_SP4608(prime_rank=PRIME_RANK, prime_layers=PRIME_LAYERS).to(device).bfloat16()

sd = torch.load(CHECKPOINT, map_location="cpu", weights_only=True)
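    # Copy only checkpoint keys whose shapes match the model; any prime-MLP
    # params missing from the checkpoint keep their fresh (zero-init down-proj) values.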
model_sd = model.state_dict()
loaded = 0
for k, v in sd.items():
if k in model_sd and model_sd[k].shape == v.shape:
model_sd[k] = v
loaded += 1
model.load_state_dict(model_sd)
print(f"loaded {loaded}/{len(sd)} keys")
print(f"total params: {sum(p.numel() for p in model.parameters()):,}")
print(f"prime params: {sum(p.numel() for _, p in model.prime_named_params()):,}")

for p in model.parameters():
p.requires_grad_(False)
for _, p in model.prime_named_params():
p.requires_grad_(True)

# Baseline (sliding window, no TTT)
print(f"\n=== Baseline sliding window (stride={STRIDE}, no TTT) ===")
reset_primes(model)
bl = eval_sliding_ttt(model, val_tokens, device, 0.0, STRIDE, *luts)
print(f"baseline: {bl:.6f}")

# TTT LR sweep
results = {}
for lr in [0.003, 0.01, 0.03, 0.1]:
print(f"\n=== TTT lr={lr} (sliding, stride={STRIDE}) ===")
reset_primes(model)
bpb = eval_sliding_ttt(model, val_tokens, device, lr, STRIDE, *luts)
results[lr] = bpb
print(f"lr={lr}: {bpb:.6f} ({bpb - bl:+.6f})")

# Summary
print(f"\n{'='*50}")
print(f"SUMMARY (sliding window stride={STRIDE}, baseline={bl:.6f})")
print(f"{'='*50}")
for lr, bpb in sorted(results.items()):
print(f" lr={lr:6.3f}: {bpb:.6f} ({bpb - bl:+.6f})")


if __name__ == "__main__":
main()