
Commit 5969a30

sunnypatneedi and claude committed
Add AdamW TTT (PR openai#481 recipe) to submission script
Upgrades TTT from PR openai#549's weak 3ep SGD (-0.0025 bpb) to PR openai#481's proven AdamW 30ep cosine + per-layer LR recipe (expected -0.01 to -0.025).

Changes:
- train_gpt.py: Added _ttt_run_phase() + ttt_adapt() + TTT hyperparams
- run_3seeds.sh: Added TTT env vars for 3-seed validation
- finalize_submission.py: Extracts pre/post TTT metrics from logs
- README.md + submission.json: Updated for TTT-enabled submission

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 53d1c27 commit 5969a30

5 files changed: 1954 additions & 32 deletions

README.md: 65 additions & 25 deletions
# LeakyReLU(0.5)^2 + AdamW TTT (30ep cosine + per-layer LR) + XSA + Int6

**val_bpb: FILL_BPB** (3-seed mean) | **FILL_MB MB** artifact | 8xH100 SXM, 600s train + ~585s eval

## Results (8xH100 80GB SXM, PyTorch 2.9.1+cu128)

| Seed | Steps | Pre-TTT BPB | Post-TTT BPB (s64) | Artifact |
|------|-------|-------------|---------------------|----------|
| 42 | FILL | FILL | FILL | FILL |
| 1337 | FILL | FILL | FILL | FILL |
| 2024 | FILL | FILL | FILL | FILL |

**Mean: FILL | Std: FILL**

## Key Innovation: AdamW TTT with cosine + per-layer LR on SOTA base

The merged SOTA (PR #549, 1.1194) uses a weak 3-epoch SGD TTT that gives only -0.0025 bpb. We replace it with PR #481's proven AdamW recipe:

1. **AdamW optimizer** (weight_decay=0) instead of SGD with momentum
2. **30 epochs** with **cosine LR decay** instead of 3 epochs flat
3. **Per-layer LR groups**: MLP output projections get 3x base LR (they are the most quant-damaged), MLP input projections get 0.5x, everything else 1x
4. **All blocks unfrozen** (freeze_blocks=0)

PR #481 demonstrated this recipe gives -0.066 bpb on their base (1.1577 -> 1.0970). On the stronger PR #549 base (~1.12 pre-TTT), we expect -0.010 to -0.025 bpb.
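The per-layer grouping in (3) can be sketched as follows. This is a sketch, not the actual train_gpt.py code: it assumes parameter names contain `mlp.proj` / `mlp.fc`, and `build_ttt_param_groups` is a hypothetical helper name.

```python
def build_ttt_param_groups(named_params, base_lr=5e-4):
    """Split (name, param) pairs into per-layer LR groups:
    MLP output projections at 3x base LR (most quant-damaged),
    MLP input projections at 0.5x, everything else at 1x."""
    proj, fc, other = [], [], []
    for name, p in named_params:
        if not getattr(p, "requires_grad", True):
            continue  # skip frozen parameters
        if "mlp.proj" in name:
            proj.append(p)
        elif "mlp.fc" in name:
            fc.append(p)
        else:
            other.append(p)
    return [
        {"params": proj, "lr": 3.0 * base_lr},
        {"params": fc, "lr": 0.5 * base_lr},
        {"params": other, "lr": base_lr},
    ]
```

The returned groups would then be passed to `torch.optim.AdamW(groups, weight_decay=0.0)` in place of `model.parameters()`.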

## Architecture (from PR #549 SOTA)

| Component | Setting |
|-----------|---------|
| Layers | 11 (512d, 8H, 4KV GQA) |
| MLP | 3x expansion, **LeakyReLU(0.5)^2** |
| BigramHash | 2048 |
| XSA | Last 4 layers |
| RoPE | Partial (16/64 dims) |
| LN Scale | 1/sqrt(layer+1) |
| VE128 | Layers 9-10 |
| Weight avg | EMA(0.997) + SWA(every 50) |
| Quantization | GPTQ-lite int6 + zstd-22 |
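For context on the quantization row: GPTQ-lite compensates quantization error column by column, but the int6 range arithmetic alone can be illustrated with a minimal symmetric per-tensor round-trip (an illustrative sketch with hypothetical names, not the GPTQ-lite code):

```python
def quantize_int6(weights):
    """Map floats to symmetric int6 codes in [-31, 31] with one per-tensor scale."""
    scale = max(abs(w) for w in weights) / 31.0 or 1.0  # avoid a zero scale
    codes = [max(-31, min(31, round(w / scale))) for w in weights]
    return codes, scale

def dequantize_int6(codes, scale):
    """Reconstruct approximate floats; per-weight error is at most scale/2."""
    return [c * scale for c in codes]
```

In the real pipeline the 6-bit codes would be bit-packed and then compressed with zstd level 22 before counting toward the 16 MB artifact budget.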

## TTT Configuration

| Parameter | Value |
|-----------|-------|
| Optimizer | AdamW (weight_decay=0) |
| Base LR | 0.0005 |
| Per-layer LR | mlp.proj: 3x, mlp.fc: 0.5x, other: 1x |
| Epochs | 30 |
| Schedule | Cosine decay |
| Freeze blocks | 0 (all unfrozen) |
| Batch seqs | 64 per GPU (512 total) |
| Max steps/epoch | 300 |
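The configuration above maps onto a short adaptation loop. The following is a sketch of what the `ttt_adapt` phase might look like; the function signature, batch format, and loss are assumptions, and the real code additionally distributes work across 8 GPUs and passes per-layer LR groups to AdamW instead of `model.parameters()`.

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

def ttt_adapt(model, batches, base_lr=5e-4, epochs=30, max_steps_per_epoch=300):
    """Test-time training: AdamW (weight_decay=0) with per-step cosine LR decay."""
    opt = torch.optim.AdamW(model.parameters(), lr=base_lr, weight_decay=0.0)
    total_steps = epochs * min(len(batches), max_steps_per_epoch)
    sched = CosineAnnealingLR(opt, T_max=total_steps)  # LR decays to ~0 at the end
    model.train()
    for _ in range(epochs):
        for step, (x, y) in enumerate(batches):
            if step >= max_steps_per_epoch:
                break  # cap at 300 steps per epoch
            loss = torch.nn.functional.cross_entropy(model(x), y)
            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()
            sched.step()
    return model
```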

## Timing Budget

| Phase | Time |
|-------|------|
| Training | 600s (10 min) |
| Int6 roundtrip eval (diagnostic) | ~20s |
| AdamW TTT (30 epochs) | ~465s |
| Sliding window eval (stride=64) | ~120s |
| **Total eval (TTT + sliding)** | **~585s (within the 600s budget)** |

## Run Command

```bash
cd /workspace/parameter-golf
SEED=42 XSA_LAST_N=4 TTT_ENABLED=1 TTT_LR=0.0005 TTT_EPOCHS=30 \
TTT_COSINE=1 TTT_PERLAYER=1 TTT_FREEZE_BLOCKS=0 TTT_BATCH_SEQS=64 \
torchrun --standalone --nproc_per_node=8 \
  records/track_10min_16mb/2026-03-24_sunnypatneedi_submission/train_gpt.py
```

## Provenance

Built on PR #549 (abaybektursun, merged SOTA 1.1194), with the TTT recipe from PR #481 (mrdavtan, 1.0970):

- PR #549 / PR #414 (signalrush) - base architecture, int6 GPTQ-lite, EMA/SWA, LeakyReLU
- PR #481 (mrdavtan) - AdamW TTT with cosine decay and per-layer LR
- PR #198 / PR #503 (jfprincz) - XSA (exclusive self-attention)
- PR #287 (jfprincz) - Partial RoPE + LN Scale

## Test Plan

- [ ] 3 seeds run on 8xH100 SXM
- [ ] All 3 seeds train in <=600s
- [ ] All 3 seeds total eval (TTT + sliding) in <=600s
- [ ] All 3 seeds artifact <=16,000,000 bytes
- [ ] Post-TTT sliding BPB beats 1.1194 by >=0.005 bpb
- [ ] Statistical significance p<0.01 across 3 seeds

finalize_submission.py: 140 additions & 0 deletions

#!/usr/bin/env python3
"""
Post-run script: reads 3 seed logs, fills in README.md and submission.json.
Run locally after scp-ing logs from RunPod.

Usage:
    python3 finalize_submission.py [submission_dir]
    # defaults to the directory containing this script
"""
import json
import re
import sys
from pathlib import Path


def extract_metrics(log_path: str) -> dict:
    """Extract key metrics from a training log."""
    text = Path(log_path).read_text()
    metrics = {}

    # Pre-TTT BPB (int6 roundtrip before TTT)
    m = re.findall(r"final_int6_roundtrip_exact.*?val_bpb:([\d.]+)", text)
    if m:
        metrics["pre_ttt_bpb"] = float(m[-1])

    # BPB from sliding window eval (the submission score, post-TTT)
    m = re.findall(r"final_int6_sliding_window_exact.*?val_bpb:([\d.]+)", text)
    if m:
        metrics["bpb"] = float(m[-1])

    # Artifact size
    m = re.findall(r"Total submission size.*?(\d+)\s*bytes", text)
    if m:
        metrics["artifact"] = int(m[-1])

    # Steps
    m = re.findall(r"stopping_early.*?step[: ]*(\d+)", text)
    if not m:
        m = re.findall(r"step[: ]*(\d+)", text)
    if m:
        metrics["steps"] = int(m[-1])

    return metrics


def main():
    sub_dir = Path(sys.argv[1]) if len(sys.argv) > 1 else Path(__file__).parent
    seeds = [42, 1337, 2024]
    results = {}

    print("Extracting metrics from logs...")
    for seed in seeds:
        log = sub_dir / f"train_seed{seed}.log"
        if not log.exists():
            print(f"  WARNING: {log} not found")
            continue
        m = extract_metrics(str(log))
        results[seed] = m
        print(f"  Seed {seed}: bpb={m.get('bpb', '?')}, artifact={m.get('artifact', '?')}, steps={m.get('steps', '?')}")

    if len(results) < 3:
        print(f"\nERROR: Only found {len(results)}/3 seed logs. Cannot finalize.")
        sys.exit(1)

    bpbs = [results[s]["bpb"] for s in seeds]
    mean_bpb = sum(bpbs) / len(bpbs)
    std_bpb = (sum((x - mean_bpb) ** 2 for x in bpbs) / len(bpbs)) ** 0.5
    max_artifact = max(results[s]["artifact"] for s in seeds)
    mean_artifact_mb = sum(results[s]["artifact"] for s in seeds) / 3 / 1_000_000

    print(f"\n  Mean BPB: {mean_bpb:.4f} (std {std_bpb:.4f})")
    print(f"  Max artifact: {max_artifact} bytes ({max_artifact/1_000_000:.2f} MB)")

    # Validation checks
    sota = 1.1194
    delta = mean_bpb - sota
    print(f"\n  vs SOTA ({sota}): {delta:+.4f} bpb")
    if delta < -0.005:
        print(f"  PASS: Beats SOTA by {abs(delta):.4f} bpb")
    elif delta < 0:
        print(f"  CLOSE: Improves by {abs(delta):.4f} bpb but < 0.005 threshold")
        print("  Consider submitting as non-record if techniques are novel.")
    else:
        print("  DOES NOT BEAT SOTA. Consider as non-record submission.")

    if max_artifact > 16_000_000:
        print(f"  FAIL: Artifact exceeds 16MB ({max_artifact} bytes)")
    else:
        print("  PASS: All artifacts under 16MB")

    # Update submission.json
    json_path = sub_dir / "submission.json"
    sj = json.loads(json_path.read_text())
    sj["val_bpb"] = round(mean_bpb, 4)
    sj["bytes_total"] = max_artifact
    sj["blurb"] = (
        f"LeakyReLU(0.5)^2 activation + AdamW TTT (30ep cosine + per-layer LR) "
        f"+ XSA on last 4 layers + Partial RoPE + LN Scale "
        f"+ VE128 + EMA/SWA + GPTQ-lite int6 + zstd-22. "
        f"Built on PR #549 stack. 3-seed mean: {mean_bpb:.4f} (std {std_bpb:.4f}). "
        f"All artifacts under 16MB."
    )
    json_path.write_text(json.dumps(sj, indent=2) + "\n")
    print(f"\n  Updated {json_path}")

    # Update README.md
    readme_path = sub_dir / "README.md"
    readme = readme_path.read_text()

    # Fill header
    readme = readme.replace("FILL_BPB", f"{mean_bpb:.4f}")
    readme = readme.replace("FILL_MB", f"{mean_artifact_mb:.2f}")

    # Fill results table (| Seed | Steps | Pre-TTT BPB | Post-TTT BPB | Artifact |)
    for seed in seeds:
        r = results[seed]
        pre_ttt = f"{r['pre_ttt_bpb']:.4f}" if "pre_ttt_bpb" in r else "?"
        old_line = f"| {seed} | FILL | FILL | FILL | FILL |"
        new_line = (
            f"| {seed} | {r.get('steps', '?')} | {pre_ttt} | {r['bpb']:.4f} "
            f"| {r['artifact']/1_000_000:.2f} MB |"
        )
        readme = readme.replace(old_line, new_line)

    # Fill mean/std
    readme = readme.replace(
        "**Mean: FILL | Std: FILL**",
        f"**Mean: {mean_bpb:.4f} | Std: {std_bpb:.4f}**",
    )

126+
readme_path.write_text(readme)
127+
print(f" Updated {readme_path}")
128+
129+
print(f"\n{'='*50}")
130+
print("SUBMISSION READY. Next steps:")
131+
print(f" 1. Review README.md and submission.json")
132+
print(f" 2. git checkout -b submission/sunnypatneedi-leakyrelu-xsa")
133+
print(f" 3. git add {sub_dir.relative_to(sub_dir.parent.parent.parent)}/")
134+
print(f" 4. git commit -m 'Add submission: LeakyReLU + XSA'")
135+
print(f" 5. git push origin submission/sunnypatneedi-leakyrelu-xsa")
136+
print(f" 6. Open PR at: https://github.com/openai/parameter-golf/compare")
137+
138+
139+
if __name__ == "__main__":
140+
main()
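The three extraction regexes in the script can be sanity-checked against a synthetic log. The log lines below are invented examples shaped to match the patterns; the real train_gpt.py output may format these lines differently:

```python
import re

log = (
    "step:4200 final_int6_roundtrip_exact val_bpb:1.1210\n"
    "step:4200 final_int6_sliding_window_exact val_bpb:1.1071\n"
    "Total submission size: 15872345 bytes\n"
)

# Same patterns as extract_metrics(): last match wins for each metric.
pre = re.findall(r"final_int6_roundtrip_exact.*?val_bpb:([\d.]+)", log)
post = re.findall(r"final_int6_sliding_window_exact.*?val_bpb:([\d.]+)", log)
size = re.findall(r"Total submission size.*?(\d+)\s*bytes", log)

print(float(pre[-1]), float(post[-1]), int(size[-1]))
```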
