Commit ddbeb48

GitGeeks and claude committed
Add SLOT warmstart: inherit delta/bias from previous batch
Adjacent eval windows overlap ~96%, so warmstarting at alpha=0.85 starts each window's SLOT optimization near the previous window's answer. SLOT_WARMSTART=0.85 enables it (default 0 = disabled). Based on the approach in PR openai#1319 (0.6951 BPB with warmstart + 64 steps).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
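A minimal sketch of the warmstart rule described in the message above, standalone and outside the training script; the batch size and hidden width are hypothetical stand-ins, and only the alpha-scaling of the previous batch's converged solution mirrors the commit:

import torch

alpha = 0.85            # value SLOT_WARMSTART takes when warmstarting is enabled
bsz, hidden = 32, 768   # hypothetical batch size and hidden width
prev_delta = None       # converged delta from the previous eval window, if any

# First window: no previous solution exists, so start from zeros.
delta = torch.zeros(bsz, 1, hidden, requires_grad=True)
# ... optimize delta for this window, then carry the result forward ...
prev_delta = delta.detach()

# Next window: start at alpha * previous solution instead of zeros, since adjacent
# windows overlap ~96% and the optimum barely moves; fall back to a zero init if
# there is no previous batch or the batch size changed.
if alpha > 0 and prev_delta is not None and prev_delta.size(0) == bsz:
    delta = (alpha * prev_delta.detach().clone()).requires_grad_(True)
else:
    delta = torch.zeros(bsz, 1, hidden, requires_grad=True)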
1 parent e348405 commit ddbeb48

1 file changed: train_gpt_slot_recurrence.py (14 additions, 3 deletions)
@@ -103,6 +103,7 @@ class Hyperparameters:
     slot_lr = float(os.environ.get("SLOT_LR", 0.008))
     slot_lr_min = float(os.environ.get("SLOT_LR_MIN", 0.0008))
     slot_batch_seqs = int(os.environ.get("SLOT_BATCH_SEQS", 32))
+    slot_warmstart = float(os.environ.get("SLOT_WARMSTART", 0.0))  # 0=disabled; 0.85=warmstart from prev batch
     # Partial depth recurrence: repeat specified layers once more in the forward pass.
     # virtual_layers = [0,1,2,3,4,5,4,5,6,7,8,9,10] when recur_layers="4,5"
     recur_layers = os.environ.get("RECUR_LAYERS", "")  # e.g., "4,5"
@@ -890,6 +891,9 @@ def eval_val_slot(
     loss_sum = torch.zeros((), device=device, dtype=torch.float64)
     token_count = torch.zeros((), device=device, dtype=torch.float64)
     byte_sum = torch.zeros((), device=device, dtype=torch.float64)
+    warmstart_alpha = args.slot_warmstart
+    prev_delta = None
+    prev_bias = None
     base_model.eval()
     for bi in range(0, len(my_ws), args.slot_batch_seqs):
         bws = my_ws[bi:bi + args.slot_batch_seqs]
@@ -916,8 +920,12 @@
         valid_count = mask.sum()
         if valid_count == 0:
             continue
-        delta = torch.zeros(bsz, 1, hidden_f.size(-1), device=device, dtype=torch.float32, requires_grad=True)
-        logit_bias = torch.zeros(bsz, 1, proj_w.size(0), device=device, dtype=torch.float32, requires_grad=True)
+        if warmstart_alpha > 0 and prev_delta is not None and prev_delta.size(0) == bsz:
+            delta = (warmstart_alpha * prev_delta.detach().clone()).requires_grad_(True)
+            logit_bias = (warmstart_alpha * prev_bias.detach().clone()).requires_grad_(True)
+        else:
+            delta = torch.zeros(bsz, 1, hidden_f.size(-1), device=device, dtype=torch.float32, requires_grad=True)
+            logit_bias = torch.zeros(bsz, 1, proj_w.size(0), device=device, dtype=torch.float32, requires_grad=True)
         slot_opt = torch.optim.AdamW([delta, logit_bias], lr=args.slot_lr, weight_decay=1e-8, eps=1e-5)
         targets_flat = yb.reshape(-1)
         for step_i in range(args.slot_steps):
@@ -932,6 +940,9 @@
             slot_loss = (nll * mask).sum() / valid_count
             slot_loss.backward()
             slot_opt.step()
+        if warmstart_alpha > 0:
+            prev_delta = delta.detach()
+            prev_bias = logit_bias.detach()
         with torch.no_grad():
             h = hidden_f + delta.detach()
             lp = F.linear(h, proj_w) + logit_bias.detach()
@@ -1496,7 +1507,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
 torch.cuda.synchronize()
 log0(
     f"final_slot val_loss:{slot_val_loss:.4f} val_bpb:{slot_val_bpb:.4f} "
-    f"steps:{args.slot_steps} lr:{args.slot_lr} eval_time:{1000.0 * (time.perf_counter() - t_slot):.0f}ms"
+    f"steps:{args.slot_steps} lr:{args.slot_lr} warmstart:{args.slot_warmstart} eval_time:{1000.0 * (time.perf_counter() - t_slot):.0f}ms"
 )
 log0(f"final_slot_exact val_loss:{slot_val_loss:.8f} val_bpb:{slot_val_bpb:.8f}")
 if distributed:
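A hedged usage sketch, assuming the script is configured purely through the environment variables shown in the Hyperparameters hunk above; set the knob before the script reads its config:

import os

# 0 (the default) keeps the old cold-start behavior; 0.85 inherits 85% of the
# previous batch's delta and logit_bias as the starting point for the next window.
os.environ["SLOT_WARMSTART"] = "0.85"
os.environ["SLOT_BATCH_SEQS"] = "32"   # unchanged default, shown only for context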
