Skip to content

Commit 717961b

Browse files
Nishant Abhangiclaude
andcommitted
autoresearch: Modal infra, experiment scripts, and training variants
- Autoresearch loop (program.md, loop.sh, generate_next.py) - Modal provider for 8xH100 training with checkpoint save/restore - Experiment framework with preflight size checks - eval_ttt.py for TTT evaluation against saved checkpoints - train_gpt_improved.py: PR openai#569 base (VRL, GPTQ, LeakyReLU², pruning) - train_gpt_576.py: PR openai#576 base (int5, 33.6M params, score-first TTT) - train_gpt_sota.py: PR openai#573 base - train_gpt_mlx_recurrent.py: depth recurrence experiments - Benchmark scripts for local MLX A/B testing Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 2376a1e commit 717961b

22 files changed

Lines changed: 8944 additions & 0 deletions

autoresearch/best_ideas.md

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# Best Ideas for Parameter Golf
2+
3+
Ranked by expected impact, based on analysis of top leaderboard submissions (as of 2026-03-23).
4+
5+
## Current Merged SOTA: 1.1233 BPB (signalrush, PR #414)
6+
Architecture: 11L, d=512, 8 heads (4 KV, GQA), MLP 3x (1536), relu², U-Net skips,
7+
Efficient Partial XSA on last 4 layers, Partial RoPE (16/64 dims), LN Scale 1/sqrt(layer+1),
8+
SmearGate + BigramHash (2048 buckets), tied embeddings, logit softcap=30.
9+
Training: Muon (lr=0.025, WD=0.04), EMA (decay=0.997), SWA (every 50 steps when scale<0.2),
10+
Late QAT (STE int6 @ lr_scale<0.15), warmdown=3500 iters. GPTQ-lite clip search post-training.
11+
12+
## Open PR SOTA: 1.0672 BPB (JoeProAI, PR #462)
13+
Same base + SwiGLU/StarReLU MLP (hidden=1792), XSA on ALL layers, U-Net + AdamW TTT at eval.
14+
TTT is the biggest single win (~0.05 BPB). Architecture changes that improve TTT are highest priority.
15+
16+
## Tier 1: High Impact (proven on leaderboard)
17+
18+
### 1. AdamW TTT (Test-Time Training)
19+
- Fine-tune ALL model weights on validation data at eval time
20+
- AdamW optimizer, lr=0.0005, 10-30 epochs, cosine schedule
21+
- Grad clip 1.0, all layers unfrozen
22+
- **Impact**: -0.053 to -0.061 BPB (massive)
23+
- **Key insight**: Architecture matters for TTT effectiveness. U-Net + gated skips create
24+
smoother loss geometry that TTT can exploit more effectively.
25+
- **Note**: TTT happens at eval time on H100s. For MLX autoresearch, focus on the
26+
architecture that maximizes TTT effectiveness, not TTT itself.
27+
28+
### 2. U-Net Skip Connections
29+
- Split layers into encoder (first N) and decoder (last M)
30+
- Encoder layers push outputs onto a stack; decoder layers pop + combine via learned sigmoid gates
31+
- `gate * x + (1-gate) * (skip_weight * skip)` where gate and skip_weight are per-dim learnable
32+
- **Impact**: Enables 2.8x more TTT gain vs standard architecture
33+
- **Synergy**: Critical for TTT effectiveness
34+
35+
### 3. XSA (Exclusive Self-Attention)
36+
- After standard attention output y = softmax(QK^T)V, subtract self-value projection:
37+
`y_out = y - proj(y, normalize(v))`
38+
- Forces attention to encode novel cross-token information, not repeat values
39+
- Apply to ALL layers (not just last 4) for best results
40+
- **Impact**: -0.002 BPB standalone, but compounds with other techniques
41+
- **Cost**: ~3ms/step extra on H100
42+
43+
### 4. SwiGLU / StarReLU MLP
44+
- Replace relu² with StarReLU: `relu(x)^2 * scale + bias` (per-channel learnable)
45+
- Or SwiGLU gating: `silu(W_gate * x) * (W_up * x)`
46+
- Top submission (#462) uses StarReLU with hidden_dim=1792
47+
- **Impact**: Improves both base model and TTT effectiveness
48+
49+
### 5. EMA (Exponential Moving Average)
50+
- Maintain shadow copy of all weights: `ema = decay * ema + (1-decay) * weights`
51+
- decay=0.997, applied every step, stored in fp32
52+
- Use EMA weights as base for quantization (smoother → less quantization damage)
53+
- **Impact**: -0.001 to -0.002 BPB, improves quantization quality
54+
55+
### 6. Partial RoPE + LN Scale (NEW — merged SOTA uses both)
56+
- Partial RoPE: Apply RoPE to only 16/64 dims (25%). Rest are position-free.
57+
Helps generalization and reduces positional overfitting.
58+
- LN Scale Factor: Scale LayerNorm output by `1/sqrt(layer_idx+1)`. Deeper layers get smaller
59+
residual contributions. Stabilizes training, especially with more layers.
60+
- **Impact**: Part of every recent SOTA. Easy to implement, no artifact size cost.
61+
62+
### 7. 11 Layers (not 9 or 10)
63+
- All merged SOTAs since 1.1307 use 11 layers with int6 quantization.
64+
- Fits under 16MB with MLP 3x and int6.
65+
- **Impact**: More depth = better features, especially with U-Net skips.
66+
67+
## Tier 2: Medium Impact (proven but smaller gains)
68+
69+
### 6. Per-Layer TTT Learning Rates
70+
- Measure quantization error per layer type
71+
- Give 3x LR to MLP output projections (most damaged by quantization)
72+
- Give 0.5x LR to MLP input projections (least damaged)
73+
- **Impact**: +23.5% TTT improvement for free
74+
75+
### 7. GPTQ-lite Clip Search
76+
- Per-row optimal clipping for int quantization
77+
- Try 5 clip percentiles [0.999, 0.9995, 0.9999, 0.99999, 1.0], pick best MSE
78+
- **Impact**: -0.0006 BPB, zero training cost
79+
- **Cost**: Post-training only, simple to implement
80+
81+
### 8. BigramHash Embeddings
82+
- Hash-based bigram lookup table (4096-12288 entries)
83+
- Adds local context signal to token embeddings
84+
- **Impact**: -0.001 to -0.002 BPB
85+
86+
### 9. Mixed-Precision Quantization
87+
- Int5 for MLP weights, Int6 for attention weights
88+
- Bitpacking for sub-byte storage (critical for int5)
89+
- **Impact**: Frees ~20% bytes vs uniform int6, fund wider model or more layers
90+
91+
### 10. Stochastic Weight Averaging (SWA)
92+
- Snapshot model weights periodically during warmdown (every 50 steps when lr < 0.2)
93+
- Average snapshots for final model
94+
- Combined with EMA for dual averaging
95+
- **Impact**: -0.001 BPB, smoother weight distributions
96+
97+
## Tier 3: Speculative / Lower Priority
98+
99+
### 11. Train Larger, Quantize Harder
100+
- d=576 (27M params) at int5 instead of d=512 (22M) at int6
101+
- Lower pre-quant loss offsets coarser quantization
102+
- Needs extended QAT (start at lr_scale < 0.50, not 0.10)
103+
- **Impact**: Competitive but not yet proven better than optimized d=512
104+
105+
### 12. Value Embeddings (ResFormer)
106+
- Per-layer value embedding with input-dependent gating
107+
- Already in Karpathy's autoresearch baseline
108+
- **Impact**: Small but consistent improvement
109+
110+
### 13. Custom Compression
111+
- ANS/arithmetic coder tuned to weight distributions
112+
- Could beat zstd by 5-10%
113+
- **Impact**: ~1MB saved, funds more parameters
114+
115+
### 14. Structured Sparsity (2:4)
116+
- Halves MLP compute, enables wider model
117+
- **Impact**: Unclear under 16MB constraint
118+
119+
## Strategy Notes
120+
121+
- **For MLX autoresearch**: Focus on architecture (U-Net, XSA, MLP type, EMA) since
122+
TTT runs at eval time on H100. The goal is to build the architecture that responds
123+
best to TTT.
124+
- **Stack incrementally**: Test each technique in isolation, then combine winners.
125+
- **Relative signal**: MLX train loss at 500 iters is a reliable relative signal.
126+
Lower train loss locally → lower val_bpb on H100.
127+
- **Artifact size**: Always monitor. Some techniques (wider model, more layers, bigram
128+
hash) increase artifact size. Must stay under 16 MB.

autoresearch/generate_next.py

Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
"""
2+
Adaptive experiment generator for Parameter Golf autoresearch.
3+
4+
Reads completed results, analyzes what worked/didn't, and generates
5+
the next batch of experiments. Called by loop.sh after each batch.
6+
7+
Strategy:
8+
1. Rank completed experiments by BPB
9+
2. Identify which technique changes improved vs hurt
10+
3. Generate new experiments that:
11+
a. Combine top-2 individual winners
12+
b. Push winning techniques further (e.g., if XSA-6 beat XSA-4, try XSA-8)
13+
c. Sweep around the best hyperparameters
14+
d. Try removing the worst-performing changes (simplify)
15+
4. Always include 1 "wild card" experiment for exploration
16+
17+
All generated experiments must be:
18+
- Legal TTT (TTT_PASSES ≤ 1)
19+
- Within 16MB artifact budget (estimated)
20+
- Not duplicates of already-run experiments
21+
"""
22+
23+
import json
24+
import sys
25+
import os
26+
from pathlib import Path
27+
28+
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
29+
from experiments import (
30+
Experiment, ExperimentResult, estimate_artifact_bytes,
31+
MAX_ARTIFACT_BYTES, EXPERIMENTS,
32+
)
33+
34+
STATE_FILE = Path(__file__).resolve().parent / "state.json"
35+
EXPERIMENTS_FILE = Path(__file__).resolve().parent.parent / "experiments.py"
36+
37+
# Default env for legal experiments
38+
BASE_ENV = {"SEED": "1337", "TTT_PASSES": "1"}
39+
40+
41+
def load_state() -> dict:
42+
if STATE_FILE.exists():
43+
return json.loads(STATE_FILE.read_text())
44+
return {"results": [], "completed_experiments": []}
45+
46+
47+
def analyze_results(results: list[dict]) -> dict:
48+
"""Analyze what worked and what didn't."""
49+
if not results:
50+
return {"winners": [], "losers": [], "baseline_bpb": None}
51+
52+
# Find baseline (no_ttt or legal_ttt_baseline)
53+
baseline = None
54+
for r in results:
55+
if r.get("experiment") in ("no_ttt_baseline", "legal_ttt_baseline"):
56+
if r.get("val_bpb"):
57+
baseline = r
58+
break
59+
60+
if not baseline:
61+
# Use the first result with a BPB as reference
62+
for r in results:
63+
if r.get("val_bpb"):
64+
baseline = r
65+
break
66+
67+
if not baseline:
68+
return {"winners": [], "losers": [], "baseline_bpb": None}
69+
70+
base_bpb = baseline["val_bpb"]
71+
72+
# Classify experiments
73+
winners = [] # Lower BPB = better
74+
losers = []
75+
for r in results:
76+
if not r.get("val_bpb") or r["experiment"] == baseline["experiment"]:
77+
continue
78+
delta = r["val_bpb"] - base_bpb
79+
entry = {"experiment": r["experiment"], "val_bpb": r["val_bpb"], "delta": delta}
80+
if delta < -0.0005: # Improved by at least 0.0005
81+
winners.append(entry)
82+
elif delta > 0.001: # Hurt by more than 0.001
83+
losers.append(entry)
84+
85+
winners.sort(key=lambda x: x["delta"])
86+
losers.sort(key=lambda x: x["delta"], reverse=True)
87+
88+
return {
89+
"winners": winners,
90+
"losers": losers,
91+
"baseline_bpb": base_bpb,
92+
"baseline_experiment": baseline["experiment"],
93+
}
94+
95+
96+
def extract_env_diff(experiment_name: str) -> dict:
97+
"""Get the env overrides for a named experiment from EXPERIMENTS list."""
98+
for exp in EXPERIMENTS:
99+
if exp.name == experiment_name:
100+
# Return only the non-default overrides
101+
diff = {}
102+
for k, v in exp.env.items():
103+
if k == "SEED" or k == "TTT_PASSES":
104+
continue
105+
diff[k] = v
106+
return diff
107+
return {}
108+
109+
110+
def generate_next_batch(analysis: dict, completed: list[str], batch_size: int = 5) -> list[Experiment]:
111+
"""Generate the next batch of experiments based on results analysis."""
112+
new_experiments = []
113+
used_names = set(completed)
114+
115+
def _add(name, desc, env_extra=None, patches=None):
116+
if name in used_names or len(new_experiments) >= batch_size:
117+
return
118+
env = {**BASE_ENV}
119+
if env_extra:
120+
env.update(env_extra)
121+
exp = Experiment(name=name, description=desc, env=env, patches=patches or [])
122+
est = estimate_artifact_bytes(env)
123+
if est <= MAX_ARTIFACT_BYTES:
124+
new_experiments.append(exp)
125+
used_names.add(name)
126+
127+
winners = analysis.get("winners", [])
128+
losers = analysis.get("losers", [])
129+
130+
# Strategy 1: Combine top-2 winners
131+
if len(winners) >= 2:
132+
w1_env = extract_env_diff(winners[0]["experiment"])
133+
w2_env = extract_env_diff(winners[1]["experiment"])
134+
combined_env = {**w1_env, **w2_env}
135+
name = f"combo_{winners[0]['experiment']}_plus_{winners[1]['experiment']}"[:60]
136+
desc = f"Combine #{1} {winners[0]['experiment']} ({winners[0]['delta']:+.4f}) + #{2} {winners[1]['experiment']} ({winners[1]['delta']:+.4f})"
137+
_add(name, desc, combined_env)
138+
139+
# Strategy 2: Combine top-3 winners
140+
if len(winners) >= 3:
141+
combined_env = {}
142+
for w in winners[:3]:
143+
combined_env.update(extract_env_diff(w["experiment"]))
144+
name = "combo_top3_winners"
145+
desc = f"Combine top 3: {', '.join(w['experiment'] for w in winners[:3])}"
146+
_add(name, desc, combined_env)
147+
148+
# Strategy 3: Push winning hyperparameters further
149+
for w in winners[:3]:
150+
env_diff = extract_env_diff(w["experiment"])
151+
for key, val in env_diff.items():
152+
try:
153+
fval = float(val)
154+
# If this was an increase from baseline, try going further
155+
# If it was a decrease, try going even lower
156+
for mult, suffix in [(1.5, "more"), (0.5, "less")]:
157+
new_val = fval * mult
158+
name = f"{w['experiment']}_{suffix}"
159+
desc = f"Push {key}={new_val} ({suffix} than {val})"
160+
_add(name, desc, {key: str(new_val)})
161+
except (ValueError, TypeError):
162+
pass
163+
164+
# Strategy 4: Interpolate between winner and baseline
165+
for w in winners[:2]:
166+
env_diff = extract_env_diff(w["experiment"])
167+
for key, val in env_diff.items():
168+
try:
169+
fval = float(val)
170+
# Try halfway between baseline default and winning value
171+
# (We don't know the baseline default here, so skip this for now)
172+
pass
173+
except (ValueError, TypeError):
174+
pass
175+
176+
# Strategy 5: Wild card — try something not yet tested
177+
wild_cards = [
178+
("seq_len_4096", "Longer sequence length (4096 vs 2048)", {"TRAIN_SEQ_LEN": "4096", "EVAL_SEQ_LEN": "4096"}),
179+
("rope_base_50k", "Higher RoPE base (50000 vs 10000)", {"ROPE_BASE": "50000"}),
180+
("softcap_50", "Higher logit softcap (50 vs 30)", {"LOGIT_SOFTCAP": "50.0"}),
181+
("softcap_20", "Lower logit softcap (20 vs 30)", {"LOGIT_SOFTCAP": "20.0"}),
182+
("qk_gain_2", "Higher QK gain init (2.0 vs 1.5)", {"QK_GAIN_INIT": "2.0"}),
183+
("muon_momentum_095", "Lower Muon momentum (0.95 vs 0.99)", {"MUON_MOMENTUM": "0.95"}),
184+
("embed_lr_08", "Higher embed LR (0.8 vs 0.6)", {"EMBED_LR": "0.8"}),
185+
]
186+
for name, desc, env_extra in wild_cards:
187+
_add(name, desc, env_extra)
188+
if len(new_experiments) >= batch_size:
189+
break
190+
191+
return new_experiments
192+
193+
194+
def main():
195+
state = load_state()
196+
results = state.get("results", [])
197+
completed = state.get("completed_experiments", [])
198+
199+
print(f"Completed experiments: {len(completed)}")
200+
print(f"Results with BPB: {sum(1 for r in results if r.get('val_bpb'))}")
201+
202+
analysis = analyze_results(results)
203+
print(f"\nBaseline: {analysis.get('baseline_bpb', 'N/A')}")
204+
print(f"Winners ({len(analysis['winners'])}):")
205+
for w in analysis["winners"]:
206+
print(f" {w['delta']:+.4f}{w['experiment']} ({w['val_bpb']:.4f})")
207+
print(f"Losers ({len(analysis['losers'])}):")
208+
for l in analysis["losers"]:
209+
print(f" {l['delta']:+.4f}{l['experiment']} ({l['val_bpb']:.4f})")
210+
211+
new_batch = generate_next_batch(analysis, completed)
212+
print(f"\nGenerated {len(new_batch)} new experiments:")
213+
for exp in new_batch:
214+
est = estimate_artifact_bytes(exp.env)
215+
print(f" {exp.name}: {exp.description} (~{est/1e6:.1f} MB)")
216+
217+
if new_batch:
218+
# Append to experiments.py EXPERIMENTS list
219+
# Write them as a separate file that the provider can pick up
220+
out = Path(__file__).resolve().parent / "next_batch.json"
221+
batch_data = []
222+
for exp in new_batch:
223+
batch_data.append({
224+
"name": exp.name,
225+
"description": exp.description,
226+
"env": exp.env,
227+
"patches": exp.patches,
228+
})
229+
out.write_text(json.dumps(batch_data, indent=2))
230+
print(f"\nWrote {len(new_batch)} experiments to {out}")
231+
232+
return new_batch
233+
234+
235+
if __name__ == "__main__":
236+
main()

0 commit comments

Comments
 (0)