#!/usr/bin/env python3
"""
Legal Score-First TTT Eval for Clark's Model
============================================
Loads a trained Clark model, adds LoRA adapters to Q and V,
runs strict score-first TTT, reports BPB.

PROTOCOL (100% legal, same as PR #549 approved by valerio-oai):
  For each chunk:
    1. SCORE: forward pass, compute loss (eval mode, no grad)
    2. Record loss for BPB calculation
    3. TRAIN: gradient update on scored chunk (AFTER scoring)
    4. NEXT: use updated model for next chunk

USAGE on H100 (after Clark's train_gpt.py has trained a model):
    python3 clark_ttt_eval.py

Requires Clark's train_gpt.py in the same directory (as module).
Loads model checkpoint from final_model.pt or trains briefly for testing.
"""
import sys; sys.stdout.reconfigure(line_buffering=True)  # flush stdout per line so progress shows up promptly in logs
sys.path.insert(0, '/workspace/repo')  # must happen BEFORE `import train_gpt` below

import os, time, math, copy
os.chdir('/workspace/repo')  # relative paths (checkpoint, data globs) resolve against the repo root

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from pathlib import Path

# Single-device script. CPU fallback exists, but the autocast("cuda", ...) blocks
# below assume a CUDA device in practice (intended target: one H100).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {DEVICE}")
print(f"Legal Score-First TTT Eval — {time.strftime('%H:%M:%S')}")

# ============================================================
# Load Clark's code as module
# ============================================================
# NOTE(review): train_gpt.py is expected at /workspace/repo (see sys.path insert above).
import train_gpt as tg
| 42 | + |
| 43 | +# ============================================================ |
| 44 | +# LoRA wrapper |
| 45 | +# ============================================================ |
class LoRAWrapper(nn.Module):
    """Wrap a Linear/CastedLinear layer with a LoRA residual branch.

    Output is ``base(x) + (x @ A @ B) * (alpha / rank)``. ``lora_B`` is
    zero-initialised, so at construction the wrapper computes exactly what the
    base layer does. Only ``lora_A`` and ``lora_B`` are trainable; every base
    parameter is frozen.

    Args:
        base_linear: layer exposing ``in_features``, ``out_features``,
            ``weight`` and at least one parameter (used to pick the device).
        rank: LoRA rank (width of the low-rank bottleneck).
        alpha: LoRA scaling numerator. ``None`` (default) keeps the original
            behaviour ``scale = 1 / rank``; pass e.g. ``alpha=16`` for the
            conventional ``alpha / rank`` scaling.
    """
    def __init__(self, base_linear, rank=8, alpha=None):
        super().__init__()
        self.base = base_linear
        in_f = base_linear.in_features
        out_f = base_linear.out_features
        # Standard LoRA scaling alpha/rank; alpha=None preserves legacy 1/rank.
        self.scale = (1.0 if alpha is None else float(alpha)) / rank
        # Create the adapters on the same device as the frozen base weights.
        device = next(base_linear.parameters()).device
        self.lora_A = nn.Parameter(torch.randn(in_f, rank, device=device) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(rank, out_f, device=device))
        # Freeze the wrapped layer — only the LoRA factors receive gradients.
        for p in self.base.parameters():
            p.requires_grad = False

    def forward(self, x):
        # Low-rank residual added on top of the frozen base projection.
        return self.base(x) + (x @ self.lora_A @ self.lora_B) * self.scale

    @property
    def in_features(self):
        return self.base.in_features

    @property
    def out_features(self):
        return self.base.out_features

    @property
    def weight(self):
        # Expose the base weight so shape/dtype introspection keeps working.
        return self.base.weight
| 74 | + |
| 75 | + |
def add_lora(model, rank=8):
    """Attach LoRA adapters to the Q and V projections of every attention block.

    Freezes every existing model parameter first; only the freshly created
    LoRA factors stay trainable. Returns the flat list of LoRA parameters.
    """
    # Freeze the entire base model before installing adapters.
    for param in model.parameters():
        param.requires_grad = False

    lora_params = []
    for block in model.blocks:
        attn = block.attn
        # Swap each projection in place for its LoRA-wrapped counterpart.
        for proj_name in ("c_q", "c_v"):
            wrapped = LoRAWrapper(getattr(attn, proj_name), rank=rank)
            setattr(attn, proj_name, wrapped)
            lora_params += [wrapped.lora_A, wrapped.lora_B]

    n_lora = sum(p.numel() for p in lora_params)
    print(f"  LoRA: rank={rank}, {n_lora:,} params on Q,V in {len(model.blocks)} layers")
    return lora_params
| 97 | + |
| 98 | + |
| 99 | +# ============================================================ |
| 100 | +# Score-First TTT |
| 101 | +# ============================================================ |
def score_first_ttt(model, val_tokens, lora_params, h,
                    chunk_size=2048, epochs=3, lr=0.001,
                    byte_luts=None):
    """Strict score-first TTT. Score chunk → record loss → train on it → next chunk.

    Args:
        model: GPT exposing ``forward_logits(x) -> (B, T, vocab)``.
        val_tokens: 1-D tensor of token ids to evaluate on.
        lora_params: the only parameters the optimizer updates.
        h: hyperparameter object; only ``h.vocab_size`` is read here.
        chunk_size: tokens scored (and then trained on) per step.
        epochs: gradient steps taken on each chunk after it has been scored.
        lr: AdamW learning rate for the LoRA parameters.
        byte_luts: optional (base, space, boundary) per-token-id byte LUTs
            used to convert token NLL into bits-per-byte.

    Returns:
        (avg_loss, bpb, elapsed_seconds). With ``byte_luts=None`` the "bpb"
        value is actually bits per token (``avg_loss / ln 2``).
    """
    optimizer = torch.optim.AdamW(lora_params, lr=lr, betas=(0.9, 0.95))

    n_tokens = val_tokens.numel()
    n_chunks = (n_tokens - 1) // chunk_size  # -1: every input token needs a next-token target
    vocab_size = h.vocab_size

    total_nll = 0.0
    total_scored = 0
    total_bytes = 0.0
    t0 = time.time()

    for c in range(n_chunks):
        start = c * chunk_size
        end = min(start + chunk_size + 1, n_tokens)  # +1 extra token for the shifted target
        chunk = val_tokens[start:end].to(device=DEVICE, dtype=torch.long)
        if len(chunk) < 2:
            continue  # cannot form an (input, target) pair

        x = chunk[:-1].unsqueeze(0)
        y = chunk[1:].unsqueeze(0)
        n_tok = y.numel()

        # === STEP 1: SCORE (eval mode, no gradients) ===
        # The model has NOT yet been updated on this chunk — scoring always
        # precedes training, which is what makes this protocol "legal".
        model.eval()
        with torch.no_grad():
            with torch.autocast("cuda", torch.bfloat16):
                logits = model.forward_logits(x)
                loss = F.cross_entropy(logits.float().reshape(-1, vocab_size), y.reshape(-1))

        total_nll += loss.item() * n_tok  # token-weighted NLL sum
        total_scored += n_tok

        # Byte counting for BPB
        if byte_luts is not None:
            base_lut, space_lut, boundary_lut = byte_luts
            prev_ids = x.reshape(-1)
            tgt_ids = y.reshape(-1)
            tb = base_lut[tgt_ids].to(torch.int16)
            # Leading-space byte counts only when the previous token does not
            # already end on a word boundary.
            tb += (space_lut[tgt_ids] & ~boundary_lut[prev_ids]).to(torch.int16)
            total_bytes += tb.float().sum().item()

        # === STEP 2: TRAIN on scored chunk (AFTER scoring) ===
        # Skipped after the final chunk — nothing left to score.
        if c < n_chunks - 1:
            model.train()
            for ep in range(epochs):
                with torch.autocast("cuda", torch.bfloat16):
                    logits = model.forward_logits(x)
                    train_loss = F.cross_entropy(logits.float().reshape(-1, vocab_size), y.reshape(-1))
                optimizer.zero_grad()
                train_loss.backward()
                torch.nn.utils.clip_grad_norm_(lora_params, 1.0)
                optimizer.step()

        # Progress
        if (c + 1) % 50 == 0 or c == n_chunks - 1:
            avg_loss = total_nll / total_scored
            if total_bytes > 0:
                bpb = (avg_loss / math.log(2)) * (total_scored / total_bytes)
                print(f"  TTT [{c+1}/{n_chunks}] loss={avg_loss:.4f} bpb={bpb:.4f} ({time.time()-t0:.0f}s)", flush=True)
            else:
                print(f"  TTT [{c+1}/{n_chunks}] loss={avg_loss:.4f} ({time.time()-t0:.0f}s)", flush=True)

    avg_loss = total_nll / total_scored
    if total_bytes > 0:
        bpb = (avg_loss / math.log(2)) * (total_scored / total_bytes)
    else:
        bpb = avg_loss / math.log(2)  # bits-per-token fallback when no byte LUTs given

    return avg_loss, bpb, time.time() - t0
| 175 | + |
| 176 | + |
| 177 | +# ============================================================ |
| 178 | +# Main |
| 179 | +# ============================================================ |
if __name__ == "__main__":
    print("\n=== Building model ===")
    h = tg.Hyperparameters()

    # Load tokenizer + byte LUTs (needed to convert token NLL into bits-per-byte)
    import sentencepiece as spm
    sp = spm.SentencePieceProcessor(model_file=h.tokenizer_path)
    byte_luts = tg.build_sentencepiece_luts(sp, h.vocab_size, torch.device(DEVICE))

    # Load validation tokens — h.val_files is a glob pattern STRING
    val_tokens = tg.load_validation_tokens(h.val_files, h.eval_seq_len)
    print(f"Val tokens: {val_tokens.numel():,}")

    # Build model
    model = tg.GPT(h).to(DEVICE)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Model: {n_params:,} params")

    # Load checkpoint if available, else quick train
    ckpt_path = Path("final_model.pt")
    if ckpt_path.exists():
        print(f"Loading checkpoint from {ckpt_path}...")
        state = torch.load(ckpt_path, map_location=DEVICE, weights_only=True)
        # strict=False tolerates missing/extra state-dict keys
        model.load_state_dict(state, strict=False)
        print("Checkpoint loaded")
    else:
        # Smoke-test fallback: 200 quick AdamW steps on the first training
        # shard so the TTT eval below has a non-random model to work with.
        print("\n=== No checkpoint — quick training (200 steps) ===")
        train_files = sorted(Path(h.datasets_dir).glob("fineweb_train_*.bin"))
        if not train_files:
            print("ERROR: No training data found")
            sys.exit(1)
        train_shard = tg.load_data_shard(train_files[0])
        optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=h.muon_wd)
        model.train()
        for step in range(200):
            # Sequential, unshuffled batches of 8 sequences; wrap to the shard
            # start if a step would run past the end.
            start_idx = step * h.train_seq_len * 8
            if start_idx + h.train_seq_len * 8 + 1 > train_shard.numel():
                start_idx = 0
            chunk = train_shard[start_idx:start_idx + h.train_seq_len * 8 + 1].to(DEVICE, torch.long)
            x = chunk[:-1].reshape(-1, h.train_seq_len)[:8]
            y = chunk[1:].reshape(-1, h.train_seq_len)[:8]
            with torch.autocast("cuda", torch.bfloat16):
                loss = model(x, y)
            optimizer.zero_grad(); loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            if step % 100 == 0:
                print(f"  Step {step}: loss={loss.item():.4f}")

    # === Eval WITHOUT TTT ===
    # Baseline BPB over the same leading n_eval validation tokens the TTT run uses.
    print("\n=== Eval WITHOUT TTT ===")
    model.eval()
    n_eval = min(500000, val_tokens.numel() - 1)
    chunk_size = h.eval_seq_len
    n_chunks = n_eval // chunk_size

    base_lut, space_lut, boundary_lut = byte_luts
    total_nll = 0.0; total_tok = 0; total_bytes = 0.0

    with torch.no_grad():
        for c in range(n_chunks):
            s = c * chunk_size
            chunk = val_tokens[s:s + chunk_size + 1].to(DEVICE, torch.long)  # +1 for the shifted target
            x = chunk[:-1].unsqueeze(0)
            y = chunk[1:].unsqueeze(0)
            with torch.autocast("cuda", torch.bfloat16):
                logits = model.forward_logits(x)
                loss = F.cross_entropy(logits.float().reshape(-1, h.vocab_size), y.reshape(-1))
            total_nll += loss.item() * y.numel()
            total_tok += y.numel()
            # Same byte-count rule as score_first_ttt: count the leading-space
            # byte only when the previous token doesn't end on a boundary.
            tb = base_lut[y.reshape(-1)].to(torch.int16)
            tb += (space_lut[y.reshape(-1)] & ~boundary_lut[x.reshape(-1)]).to(torch.int16)
            total_bytes += tb.float().sum().item()

    pre_loss = total_nll / total_tok
    pre_bpb = (pre_loss / math.log(2)) * (total_tok / total_bytes)
    print(f"Pre-TTT: loss={pre_loss:.4f} bpb={pre_bpb:.4f} ({total_tok:,} tokens)")

    # === Add LoRA + Run TTT ===
    # Deep-copy so the baseline model is untouched by the TTT gradient updates.
    print("\n=== Score-First TTT (LoRA rank=8) ===")
    ttt_model = copy.deepcopy(model)
    lora_params = add_lora(ttt_model, rank=8)

    ttt_loss, ttt_bpb, ttt_time = score_first_ttt(
        ttt_model, val_tokens[:n_eval + 1], lora_params, h,
        chunk_size=chunk_size, epochs=3, lr=0.001,
        byte_luts=byte_luts
    )

    # === Results ===
    # NOTE(review): positive "Change" means bpb went UP (TTT hurt); negative means it helped.
    improvement = (ttt_bpb - pre_bpb) / pre_bpb * 100
    print(f"\n{'='*60}")
    print(f"RESULTS")
    print(f"{'='*60}")
    print(f"Pre-TTT: loss={pre_loss:.4f} bpb={pre_bpb:.4f}")
    print(f"Post-TTT: loss={ttt_loss:.4f} bpb={ttt_bpb:.4f}")
    print(f"Change: {improvement:+.2f}%")
    print(f"TTT time: {ttt_time:.0f}s")
    print(f"Tokens: {total_tok:,}")
0 commit comments