records/track_non_record_16mb/2026-04-10_AdamTTT_QKGain/README.md

# Non-Record Submission: Adam-TTT + QK-Gain Search

**Track:** non-record (no training logs — compute unavailable before deadline)
**Base:** PR #1493 (bigbag, 1.0810 BPB)

This submission presents a novel algorithmic idea — replacing SGD with Adam in the TTT eval loop — without validation results, due to lack of compute access. It is submitted to put the idea on the record and to invite community review.

## Key Changes

### 1. Adam-TTT (primary novel contribution)
The SOTA TTT loop uses SGD with momentum=0.9. We replace it with **Adam** (beta1=0.9, beta2=0.999).

**Why Adam is better for TTT:**
- Adam adapts per-parameter learning rates based on gradient history
- TTT is essentially fast fine-tuning — Adam is the standard choice for fine-tuning
- SGD with fixed momentum can overshoot on some parameters while undershooting on others
- Adam's normalization by second moment makes it less sensitive to LR choice
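
To make the per-parameter adaptation concrete, here is the textbook Adam update written out in plain PyTorch. This is illustrative only, not the project's TTT code; the defaults shown match the hyperparameters listed further down.

```python
import torch

def adam_step(p, grad, m, v, t, lr=0.005, beta1=0.9, beta2=0.999, eps=1e-8):
    """One bias-corrected Adam update for a single parameter tensor (textbook form)."""
    m.mul_(beta1).add_(grad, alpha=1 - beta1)            # first moment (momentum-like)
    v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # second moment (squared grads)
    m_hat = m / (1 - beta1 ** t)                         # bias correction, t = step count
    v_hat = v / (1 - beta2 ** t)
    p.sub_(lr * m_hat / (v_hat.sqrt() + eps))            # element-wise step size
```

The division by `v_hat.sqrt()` rescales each coordinate individually, which is the per-parameter adaptation the bullets above rely on; plain SGD with momentum applies one global learning rate instead.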

Controlled by `TTT_OPTIMIZER=adam` (default). Also supports `TTT_OPTIMIZER=adamw` (with weight decay) and `TTT_OPTIMIZER=sgd` (original behavior).

New hyperparameters:
- `TTT_OPTIMIZER` — optimizer choice (default: `adam`)
- `TTT_ADAM_BETA1` — Adam beta1 (default: `0.9`)
- `TTT_ADAM_BETA2` — Adam beta2 (default: `0.999`)
- `TTT_ADAM_EPS` — Adam epsilon (default: `1e-8`)
- `TTT_WEIGHT_DECAY` — weight decay for AdamW (default: `0.0`)
- `TTT_FRESH_OPTIMIZER` — reset optimizer state each chunk (default: `0`)
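
A minimal sketch of what the selection helper could look like, using the fields above and those exercised by `smoke_test.py`; the actual `_make_ttt_opt` in `train_gpt.py` may differ in detail:

```python
import torch

def _make_ttt_opt(h, params):
    # Pick the TTT optimizer from TTT_OPTIMIZER; 'sgd' keeps the original
    # momentum=0.9 behavior of the SOTA stack.
    if h.ttt_optimizer == "adam":
        return torch.optim.Adam(params, lr=h.ttt_lr,
                                betas=(h.ttt_adam_beta1, h.ttt_adam_beta2),
                                eps=h.ttt_adam_eps)
    if h.ttt_optimizer == "adamw":
        return torch.optim.AdamW(params, lr=h.ttt_lr,
                                 betas=(h.ttt_adam_beta1, h.ttt_adam_beta2),
                                 eps=h.ttt_adam_eps,
                                 weight_decay=h.ttt_weight_decay)
    return torch.optim.SGD(params, lr=h.ttt_lr, momentum=h.ttt_momentum)
```

With `TTT_FRESH_OPTIMIZER=1`, the caller would presumably rebuild this optimizer at the start of each chunk rather than carrying moment estimates across chunks.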

### 2. QK-Gain above 5.25
The SOTA showed monotonic improvement from QK_GAIN_INIT=4.0→4.5→5.0→5.25.
We search 5.5, 6.0, 6.5 to find the optimum.

## Architecture (unchanged from PR #1493)

11L × 512d × 8H / 4KV, MLP 4×, LeakyReLU(0.5)², Partial RoPE (16/64 dims),
layerwise LN scale, tied embeddings, logit softcap=30.0.
Depth recurrence: loops layers 3-5, activated at ~35% of training steps.
Parallel residuals from layer 7.
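
For orientation, this description could map onto the hyperparameter fields that `smoke_test.py` exercises via its `FakeH` stub roughly as follows; the field names come from the smoke test, but the full-size values and the exact loop/residual indexing are assumptions, not the authoritative config:

```python
# Hypothetical full-size config mirroring the FakeH field names from smoke_test.py.
arch = dict(
    num_layers=11, model_dim=512, num_heads=8, num_kv_heads=4,
    mlp_mult=4.0,
    rope_dims=16,                  # partial RoPE: 16 of the 64 head dims (assumed)
    logit_softcap=30.0,
    ln_scale=True,                 # layerwise LN scale
    tie_embeddings=True,
    loop_start=3, loop_end=5,      # depth recurrence over layers 3-5 (indexing assumed)
    enable_looping_at=0.35,        # looping activated ~35% into training
    parallel_residual_start=7,     # parallel residuals from layer 7
    qk_gain_init=5.25,
)
```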

## Experiment Plan

| Experiment | TTT_OPTIMIZER | QK_GAIN_INIT | Expected |
|---|---|---|---|
| 1 (Adam baseline) | adam | 5.25 | vs SOTA 1.0810 |
| 2 (higher QK-gain) | adam | 5.5 | better? |
| 3 (even higher) | adam | 6.0 | better? |
| 4 (AdamW TTT) | adamw | best above | better? |
| 5 (fresh Adam per chunk) | adam + fresh | best above | better? |

## Reproduction

```bash
pip install brotli sentencepiece
pip install flash_attn_3 --no-deps --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch291/
python3 data/cached_challenge_fineweb.py --variant sp8192

# Experiment 1: Adam-TTT at QK-gain 5.25
SEED=42 QK_GAIN_INIT=5.25 TTT_ENABLED=1 TTT_OPTIMIZER=adam TTT_LR=0.005 TTT_EPOCHS=3 \
torchrun --standalone --nproc_per_node=8 train_gpt.py

# Experiment 2: Higher QK-gain
SEED=42 QK_GAIN_INIT=5.5 TTT_ENABLED=1 TTT_OPTIMIZER=adam TTT_LR=0.005 TTT_EPOCHS=3 \
torchrun --standalone --nproc_per_node=8 train_gpt.py
```

## Credits

- **@bigbag** — PR #1493 SOTA stack (all core techniques)
- **Lakshay Bansal** — Adam-TTT replacement, QK-gain search
records/track_non_record_16mb/2026-04-10_AdamTTT_QKGain/smoke_test.py

"""
CPU smoke test — verifies the Adam-TTT changes work without GPU/flash_attn.
Run: python3 smoke_test.py
Expected: prints "SMOKE TEST PASSED" at the end.
"""
import math, os, sys, types
import torch
import torch.nn.functional as F
from torch import nn, Tensor

# ── Stub out flash_attn so we can run on CPU ──────────────────────────────────
def _fake_flash_attn(q, k, v, causal=True):
    # q/k/v: (B, T, H, D) — reshape to (B, H, T, D) for sdpa
    B, T, H, D = q.shape
    Hkv = k.shape[2]
    q2 = q.permute(0, 2, 1, 3)
    k2 = k.permute(0, 2, 1, 3).repeat_interleave(H // Hkv, dim=1)
    v2 = v.permute(0, 2, 1, 3).repeat_interleave(H // Hkv, dim=1)
    out = F.scaled_dot_product_attention(q2, k2, v2, is_causal=causal)
    return out.permute(0, 2, 1, 3)

stub = types.ModuleType("flash_attn_interface")
stub.flash_attn_func = _fake_flash_attn
sys.modules["flash_attn_interface"] = stub

# ── Now import our modified training code ────────────────────────────────────
import lzma, base64, importlib

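# train_gpt.py ships its real source as an LZMA-compressed, base85-encoded payload
# on its second line (the "LZMA code compression" technique); extract that payload,
# decompress it, and exec the decoded module below.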
train_script = open(os.path.join(os.path.dirname(__file__), "train_gpt.py")).read()
lines = train_script.strip().split('\n')
line2 = lines[1]
import re
m = re.search(r'B\.b85decode\("(.+?)"\),format=', line2)
payload = m.group(1)
code = lzma.decompress(base64.b85decode(payload), format=lzma.FORMAT_RAW,
filters=[{"id": lzma.FILTER_LZMA2}]).decode('utf-8')

# Execute into a fresh namespace
ns = {}
# Patch CUDA checks before exec
code = code.replace(
"if not torch.cuda.is_available():raise RuntimeError('CUDA is required')",
"pass # CPU smoke test: skip CUDA check"
)
code = code.replace(
"device=torch.device('cuda',local_rank);torch.cuda.set_device(device)",
"device=torch.device('cpu')"
)
code = code.replace(
"torch.backends.cuda.matmul.allow_tf32=True;torch.backends.cudnn.allow_tf32=True;torch.set_float32_matmul_precision('high');from torch.backends.cuda import enable_cudnn_sdp,enable_flash_sdp,enable_math_sdp,enable_mem_efficient_sdp;enable_cudnn_sdp(False);enable_flash_sdp(True);enable_mem_efficient_sdp(False);enable_math_sdp(False);",
""
)
# Disable torch.compile for CPU smoke test (avoids inductor bugs on Python 3.12)
code = code.replace("@torch.compile\n", "")
code = code.replace("@torch.compile(", "@torch.no_grad() # was compile(")
exec(code, ns)

# ── Test 1: Model forward pass ────────────────────────────────────────────────
print("Test 1: Model forward pass...")

class FakeH:
    vocab_size = 64
    num_layers = 4
    model_dim = 32
    embedding_dim = 32
    num_heads = 4
    num_kv_heads = 2
    mlp_mult = 2.0
    rope_base = 10000.0
    qk_gain_init = 5.25  # our new default
    train_seq_len = 16
    logit_softcap = 30.0
    rope_dims = 8
    rope_train_seq_len = 16
    ln_scale = True
    skip_gates_enabled = True
    tie_embeddings = True
    tied_embed_init_std = 0.005
    xsa_last_n = 0
    parallel_residual_start = 2
    num_loops = 1
    loop_start = 1
    loop_end = 2
    enable_looping_at = 0.35

h = FakeH()
GPT = ns['GPT']
model = GPT(h)
model.looping_active = True

x = torch.randint(0, 64, (2, 16))
y = torch.randint(0, 64, (2, 16))
loss = model(x, y)
assert loss.item() > 0, "Loss should be positive"
print(f" Forward pass OK. Loss: {loss.item():.4f}, QK-gain default: {h.qk_gain_init}")

# ── Test 2: _make_ttt_opt helper ─────────────────────────────────────────────
print("Test 2: TTT optimizer selection...")

_make_ttt_opt = ns['_make_ttt_opt']

class TttH:
    ttt_optimizer = 'sgd'
    ttt_lr = 0.005
    ttt_momentum = 0.9
    ttt_adam_beta1 = 0.9
    ttt_adam_beta2 = 0.999
    ttt_adam_eps = 1e-8
    ttt_weight_decay = 0.0

params = list(model.parameters())

# SGD
th = TttH(); th.ttt_optimizer = 'sgd'
opt = _make_ttt_opt(th, params)
assert isinstance(opt, torch.optim.SGD), "Expected SGD"
print(" SGD: OK")

# Adam (our default)
th.ttt_optimizer = 'adam'
opt = _make_ttt_opt(th, params)
assert isinstance(opt, torch.optim.Adam), "Expected Adam"
print(" Adam: OK")

# AdamW
th.ttt_optimizer = 'adamw'
opt = _make_ttt_opt(th, params)
assert isinstance(opt, torch.optim.AdamW), "Expected AdamW"
print(" AdamW: OK")

# ── Test 3: Adam-TTT step works ───────────────────────────────────────────────
print("Test 3: Adam-TTT gradient step...")

th.ttt_optimizer = 'adam'
opt = _make_ttt_opt(th, params)
opt.zero_grad()
loss = model(x, y)
loss.backward()
opt.step()
print(f" Adam step OK. Loss after step: {model(x, y).item():.4f}")

# ── Test 4: Hyperparameter class has new fields ────────────────────────────────
print("Test 4: Hyperparameter class has new fields...")
# Set env vars and check
os.environ['TTT_OPTIMIZER'] = 'adam'
os.environ['QK_GAIN_INIT'] = '5.25'
os.environ['TTT_FRESH_OPTIMIZER'] = '0'
Hyperparameters = ns['Hyperparameters']
h2 = Hyperparameters()
assert hasattr(h2, 'ttt_optimizer'), "Missing ttt_optimizer"
assert hasattr(h2, 'ttt_adam_beta1'), "Missing ttt_adam_beta1"
assert hasattr(h2, 'ttt_fresh_optimizer'), "Missing ttt_fresh_optimizer"
assert h2.ttt_optimizer == 'adam', f"Expected adam, got {h2.ttt_optimizer}"
assert h2.qk_gain_init == 5.25, f"Expected 5.25, got {h2.qk_gain_init}"
print(f" ttt_optimizer={h2.ttt_optimizer}, qk_gain_init={h2.qk_gain_init}")

print("\nSMOKE TEST PASSED")
print("\nNext steps:")
print(" 1. Apply for compute grant (OpenAI RunPod link in README)")
print(" 2. On RunPod 8xH100: run Experiment 1 (Adam-TTT, QK=5.25)")
print(" 3. Compare val_bpb to SOTA 1.0810")
print(" 4. Run Experiments 2-5 based on results")
{
  "name": "AdamTTT + QK-Gain Search",
  "author": "Lakshay Bansal",
  "github_id": "lakshaybansal",
  "val_bpb": null,
  "artifact_bytes": null,
  "track": "non-record",
  "notes": "Replaces SGD with Adam in TTT loop. Also searches QK-Gain above 5.25. Built on top of PR #1493 (bigbag SOTA). No training logs — submitted as novel algorithmic idea without compute access.",
  "base_pr": 1493,
  "techniques": [
    "SP8192",
    "3-layer depth recurrence",
    "parallel residuals (layers 7+)",
    "QK-gain tuning",
    "Adam-TTT (novel: replace SGD with Adam in test-time training)",
    "GPTQ SDClip int6/int8",
    "MuonEq-R",
    "EMA 0.9965",
    "LZMA code compression"
  ],
  "key_change": "TTT optimizer changed from SGD (momentum=0.9) to Adam (beta1=0.9, beta2=0.999). Adam adapts per-parameter learning rates, converging faster within the TTT budget. Controlled by TTT_OPTIMIZER env var (adam/adamw/sgd)."
}