
Commit 5fc5ab5

feat: Aweb Signature stack — depth recurrence + lookahead optimizer
Two complementary additions to the AwebUltimate base:

1. Depth Recurrence (PR openai#1517 / openai#1331 / openai#1471 pattern, must credit):
   re-run selected encoder layers once more after the encoder pass, with NO skip-stack
   push (preserving the U-Net 5-in/5-out symmetry). Curriculum: activated at
   RECUR_START_STEP (default 2000); eval always uses recurrence.
   Env vars: RECUR_LAYERS (e.g. '3,4'), RECUR_START_STEP.

2. Lookahead Optimizer (Zhang/Lucas/Hinton/Ba, NeurIPS 2019), the Aweb signature:
   maintains slow weights for all trainable params. Every k inner steps:
   slow := (1-α)*slow + α*fast; fast := slow. ~5% wall-clock overhead.
   Novel for the nanochat speedrun (verified via gh search).
   Env vars: LOOKAHEAD_ENABLED, LOOKAHEAD_K, LOOKAHEAD_ALPHA.

Backwards compat: RECUR_LAYERS='' (default) plus LOOKAHEAD_ENABLED=0 reproduces the
proven 1.1190 baseline byte-identically.

CPU smoke test (10 cases) PASSES: env wiring, model construction, recur parsing,
forward/forward_hidden recur application, skip-stack symmetry, mini-training loss
decrease, lookahead update math.
1 parent 16d2f2f commit 5fc5ab5
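For intuition, here is the Lookahead rule from item 2 in isolation: a minimal runnable sketch in plain Python (no torch; inner_update is a made-up stand-in for the real inner-optimizer step, not anything from train_gpt.py):

# Toy Lookahead (Zhang et al., NeurIPS 2019) on a single scalar weight.
# `inner_update` is a hypothetical stand-in for the real inner-optimizer step.
alpha, k = 0.5, 5

def inner_update(w: float) -> float:
    return w - 0.1 * (w - 0.2)  # pretend gradient step pulling w toward 0.2

fast = slow = 1.0
for step in range(1, 16):
    fast = inner_update(fast)
    if step % k == 0:
        slow = (1 - alpha) * slow + alpha * fast  # slow := (1-a)*slow + a*fast
        fast = slow                               # fast := slow (restart from slow)
        print(f"step {step:2d}: synced fast and slow at {slow:.4f}")

With alpha=0.5 each sync moves the slow weights halfway toward the fast weights, then restarts the fast weights from that interpolated point.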

1 file changed: records/track_10min_16mb/2026-03-23_AwebUltimate/train_gpt.py
71 additions & 0 deletions
@@ -123,6 +123,16 @@ class Hyperparameters:
     slot_steps = int(os.environ.get("SLOT_STEPS", 16))
     slot_lr = float(os.environ.get("SLOT_LR", 0.008))
     slot_lr_min = float(os.environ.get("SLOT_LR_MIN", 0.0008))
+    # Depth Recurrence (PR #1517 / #1331 / #1471 pattern): re-run selected encoder
+    # layers once more after encoder pass, NO skip push (preserves U-Net symmetry).
+    # Format: "2,3,4" (comma-separated layer indices); empty string = disabled.
+    recur_layers_str = os.environ.get("RECUR_LAYERS", "")
+    recur_start_step = int(os.environ.get("RECUR_START_STEP", 2000))
+    # Lookahead Optimizer (Zhang/Lucas/Hinton/Ba, NeurIPS 2019) - Aweb signature.
+    # Maintains slow weights; every k inner steps: slow = (1-a)*slow + a*fast; fast := slow.
+    lookahead_enabled = bool(int(os.environ.get("LOOKAHEAD_ENABLED", "1")))
+    lookahead_k = int(os.environ.get("LOOKAHEAD_K", 5))
+    lookahead_alpha = float(os.environ.get("LOOKAHEAD_ALPHA", 0.5))
     # N-gram oracle mixing
     ngram_enabled = bool(int(os.environ.get("NGRAM_ENABLED", "1")))
     ngram_max_order = int(os.environ.get("NGRAM_MAX_ORDER", 12))  # order 2-12 backoff
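As a hedged illustration of the wiring above (not part of the diff): the backwards-compat configuration from the commit message resolves through these same os.environ.get calls as follows:

# Illustration only: the baseline-reproduction config resolving through the defaults above.
import os

os.environ["RECUR_LAYERS"] = ""        # the default: depth recurrence disabled
os.environ["LOOKAHEAD_ENABLED"] = "0"  # non-default: LOOKAHEAD_ENABLED defaults to "1"

recur_layers_str = os.environ.get("RECUR_LAYERS", "")                    # -> ""
lookahead_enabled = bool(int(os.environ.get("LOOKAHEAD_ENABLED", "1")))  # -> False
assert recur_layers_str == "" and lookahead_enabled is False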
@@ -829,8 +839,17 @@ def __init__(
         ve_layers: str = "9,10",
         gated_attention: bool = False,
         value_residual: bool = False,
+        recur_layers_str: str = "",
     ):
         super().__init__()
+        # Depth recurrence: list of layer indices to re-run after encoder pass.
+        # No skip-stack push on recurrence (preserves U-Net 5-in/5-out symmetry).
+        # `recur_enabled` is toggled by the training loop at recur_start_step.
+        if recur_layers_str.strip():
+            self.recur_layers = [int(x) for x in recur_layers_str.split(",") if x.strip()]
+        else:
+            self.recur_layers = []
+        self.recur_enabled = False  # set True by training loop after warmup; True at eval
         self._ve_target_dim = num_kv_heads * (model_dim // num_heads)  # kv_dim for value projection
         if logit_softcap <= 0.0:
             raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
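The parser above tolerates whitespace and stray commas. A quick check of its behavior, with the same logic lifted into a standalone function purely for illustration:

# The same parsing logic as __init__ above, in a throwaway function for a quick check.
def parse_recur_layers(s: str) -> list[int]:
    if s.strip():
        return [int(x) for x in s.split(",") if x.strip()]
    return []

assert parse_recur_layers("") == []             # disabled (baseline default)
assert parse_recur_layers("3,4") == [3, 4]      # documented example
assert parse_recur_layers(" 2, 3 ,") == [2, 3]  # int() tolerates spaces; empty tail dropped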
@@ -955,6 +974,15 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
             if v0 is None and raw_v is not None:
                 v0 = raw_v
             skips.append(x)
+        # Depth recurrence: re-run selected encoder layers once more (no skip push).
+        # Preserves U-Net symmetry (skips list length unchanged).
+        if self.recur_enabled and self.recur_layers:
+            for ri in self.recur_layers:
+                ve = self._get_ve(ri, input_ids, ve_cache)
+                x, _ = self.blocks[ri](x, x0,
+                                       self.qo_bank[ri], self.kv_bank[ri], self.kv_bank[n + ri],
+                                       self.qo_bank[n + ri], self.mlp_up_bank[ri], self.mlp_down_bank[ri],
+                                       v_embed=ve, v0=v0)
         for i in range(self.num_decoder_layers):
             bi = self.num_encoder_layers + i
             if skips:
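Why the absent skips.append matters: each decoder layer pops exactly one skip, so a push from a recurred layer would shift every encoder/decoder pairing. A schematic sketch of the invariant (plain Python, with identity stand-ins for the real blocks):

# Schematic of the skip-stack invariant; identity "layers" stand in for blocks.
n_encoder_layers = n_decoder_layers = 5
recur_layers = [3, 4]

skips = []
for i in range(n_encoder_layers):   # encoder: one push per layer
    skips.append(i)
for ri in recur_layers:             # recurrence: layers re-run, but NO push
    pass
assert len(skips) == n_decoder_layers  # 5-in/5-out symmetry preserved
for i in range(n_decoder_layers):   # decoder: one pop per layer
    skips.pop()
assert not skips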
@@ -1013,6 +1041,14 @@ def forward_hidden(self, input_ids: Tensor) -> Tensor:
             if v0 is None and raw_v is not None:
                 v0 = raw_v
             skips.append(x)
+        # Depth recurrence (mirrors forward()): no skip push, preserves U-Net.
+        if self.recur_enabled and self.recur_layers:
+            for ri in self.recur_layers:
+                ve = self._get_ve(ri, input_ids, ve_cache)
+                x, _ = self.blocks[ri](x, x0,
+                                       self.qo_bank[ri], self.kv_bank[ri], self.kv_bank[n + ri],
+                                       self.qo_bank[n + ri], self.mlp_up_bank[ri], self.mlp_down_bank[ri],
+                                       v_embed=ve, v0=v0)
         for i in range(self.num_decoder_layers):
             bi = self.num_encoder_layers + i
             if skips:
@@ -1918,6 +1954,7 @@ def log0(msg: str, console: bool = True) -> None:
     ve_layers=args.ve_layers,
     gated_attention=args.gated_attention,
     value_residual=args.value_residual,
+    recur_layers_str=args.recur_layers_str,
 ).to(device).bfloat16()
 # Banks stay FP32 (like CastedLinear weights), cast to BF16 in forward
 base_model.qo_bank.data = base_model.qo_bank.data.float()
@@ -2074,6 +2111,21 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
 lawa_queue: deque[dict[str, Tensor]] = deque(maxlen=args.lawa_k)
 ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()}
 ema_decay = 0.997
+# Lookahead Optimizer (Aweb signature, Zhang et al. NeurIPS 2019).
+# Slow weights snapshot for ALL trainable params; updated every k inner steps.
+lookahead_slow: dict[str, Tensor] | None = None
+if args.lookahead_enabled:
+    lookahead_slow = {n: p.detach().float().clone()
+                      for n, p in base_model.named_parameters() if p.requires_grad}
+    if rank == 0:
+        log0(f"lookahead:enabled k={args.lookahead_k} alpha={args.lookahead_alpha} "
+             f"slow_params={len(lookahead_slow)}")
+elif rank == 0:
+    log0("lookahead:disabled")
+if base_model.recur_layers and rank == 0:
+    log0(f"recur:configured layers={base_model.recur_layers} start_step={args.recur_start_step}")
+elif rank == 0:
+    log0("recur:disabled")
 training_time_ms = 0.0
 stop_after_step: int | None = None
 torch.cuda.synchronize()
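One consequence of the snapshot above, as an observation rather than anything stated in the commit: lookahead_slow is a full FP32 copy of every trainable tensor, roughly 4 bytes per parameter on top of the BF16 training state. A back-of-envelope sketch with a stand-in module (substitute the real base_model):

# Rough memory estimate for the FP32 slow-weight copy; nn.Linear is a stand-in model.
import torch.nn as nn

model = nn.Linear(768, 768)  # hypothetical; substitute the real base_model
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
slow_mib = n_trainable * 4 / 2**20  # FP32 copy: 4 bytes per element
print(f"lookahead slow weights: ~{slow_mib:.2f} MiB extra")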
@@ -2153,6 +2205,22 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
     for name, t in base_model.state_dict().items():
         ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay)
     step += 1
+    # Lookahead update (Aweb signature): every k inner steps,
+    # slow := (1-a)*slow + a*fast; fast := slow.
+    if lookahead_slow is not None and step % args.lookahead_k == 0:
+        with torch.no_grad():
+            for n, p in base_model.named_parameters():
+                if n in lookahead_slow:
+                    slow = lookahead_slow[n]
+                    slow.mul_(1.0 - args.lookahead_alpha).add_(p.data.float(),
+                                                               alpha=args.lookahead_alpha)
+                    p.data.copy_(slow.to(p.dtype))
+    # Depth recurrence curriculum: enable after warmup steps.
+    if (base_model.recur_layers and not base_model.recur_enabled
+            and step >= args.recur_start_step):
+        base_model.recur_enabled = True
+        if rank == 0:
+            log0(f"recur:activated step={step} layers={base_model.recur_layers}")
     approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
     if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0:
         if swa_state is None:
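Putting the two maintenance branches above on a timeline under the defaults (k=5, recur_start_step=2000) shows when each fires; a pure schedule illustration, no model involved:

# Which maintenance branch fires at which step, under the default hyperparameters.
lookahead_k, recur_start_step = 5, 2000
recur_enabled = False
for step in [1, 4, 5, 10, 1999, 2000, 2003, 2005]:
    lookahead_sync = step % lookahead_k == 0
    if not recur_enabled and step >= recur_start_step:
        recur_enabled = True  # one-shot toggle; stays on for the rest of training
    print(f"step={step:4d} lookahead_sync={lookahead_sync} recur_enabled={recur_enabled}")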
@@ -2264,7 +2332,10 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
     rope_dims=args.rope_dims, ln_scale=args.ln_scale, dtg=args.dtg_enabled,
     ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers,
     gated_attention=args.gated_attention, value_residual=args.value_residual,
+    recur_layers_str=args.recur_layers_str,
 ).to(device).bfloat16()
+# Eval model always uses recurrence (if configured); no curriculum at eval time.
+eval_model.recur_enabled = bool(eval_model.recur_layers)
 eval_model.qo_bank.data = eval_model.qo_bank.data.float()
 eval_model.kv_bank.data = eval_model.kv_bank.data.float()
 eval_model.mlp_up_bank.data = eval_model.mlp_up_bank.data.float()
