@@ -123,6 +123,16 @@ class Hyperparameters:
     slot_steps = int(os.environ.get("SLOT_STEPS", 16))
     slot_lr = float(os.environ.get("SLOT_LR", 0.008))
     slot_lr_min = float(os.environ.get("SLOT_LR_MIN", 0.0008))
+    # Depth recurrence (PR #1517 / #1331 / #1471 pattern): re-run selected encoder
+    # layers once more after the encoder pass, with NO skip push (preserves U-Net symmetry).
+    # Format: "2,3,4" (comma-separated layer indices); empty string = disabled.
+    recur_layers_str = os.environ.get("RECUR_LAYERS", "")
+    recur_start_step = int(os.environ.get("RECUR_START_STEP", 2000))
+    # Lookahead optimizer (Zhang/Lucas/Hinton/Ba, NeurIPS 2019) - Aweb signature.
+    # Maintains slow weights; every k inner steps: slow = (1-a)*slow + a*fast; fast := slow.
+    lookahead_enabled = bool(int(os.environ.get("LOOKAHEAD_ENABLED", "1")))
+    lookahead_k = int(os.environ.get("LOOKAHEAD_K", 5))
+    lookahead_alpha = float(os.environ.get("LOOKAHEAD_ALPHA", 0.5))
     # N-gram oracle mixing
     ngram_enabled = bool(int(os.environ.get("NGRAM_ENABLED", "1")))
     ngram_max_order = int(os.environ.get("NGRAM_MAX_ORDER", 12))  # order 2-12 backoff
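Since every knob above is environment-driven, here is a minimal standalone sketch of how the two parsing conventions behave (only the env-var names come from the diff; the values are illustrative):

```python
import os

# Comma-separated layer indices: whitespace around entries is tolerated,
# because each element is int()-converted after the blank filter.
os.environ["RECUR_LAYERS"] = " 2, 3 ,4"
os.environ["LOOKAHEAD_ENABLED"] = "0"   # flags are "0"/"1" strings

layers = [int(x) for x in os.environ.get("RECUR_LAYERS", "").split(",") if x.strip()]
enabled = bool(int(os.environ.get("LOOKAHEAD_ENABLED", "1")))
print(layers, enabled)  # [2, 3, 4] False
```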
@@ -829,8 +839,17 @@ def __init__(
         ve_layers: str = "9,10",
         gated_attention: bool = False,
         value_residual: bool = False,
+        recur_layers_str: str = "",
     ):
         super().__init__()
+        # Depth recurrence: list of layer indices to re-run after the encoder pass.
+        # No skip-stack push on recurrence (preserves U-Net 5-in/5-out symmetry).
+        # `recur_enabled` is toggled by the training loop at recur_start_step.
+        if recur_layers_str.strip():
+            self.recur_layers = [int(x) for x in recur_layers_str.split(",") if x.strip()]
+        else:
+            self.recur_layers = []
+        self.recur_enabled = False  # set True by the training loop after warmup; always True at eval
         self._ve_target_dim = num_kv_heads * (model_dim // num_heads)  # kv_dim for value projection
         if logit_softcap <= 0.0:
             raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
@@ -955,6 +974,15 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
             if v0 is None and raw_v is not None:
                 v0 = raw_v
             skips.append(x)
+        # Depth recurrence: re-run selected encoder layers once more (no skip push).
+        # Preserves U-Net symmetry (the skips list length is unchanged).
+        if self.recur_enabled and self.recur_layers:
+            for ri in self.recur_layers:
+                ve = self._get_ve(ri, input_ids, ve_cache)
+                x, _ = self.blocks[ri](x, x0,
+                    self.qo_bank[ri], self.kv_bank[ri], self.kv_bank[n + ri],
+                    self.qo_bank[n + ri], self.mlp_up_bank[ri], self.mlp_down_bank[ri],
+                    v_embed=ve, v0=v0)
         for i in range(self.num_decoder_layers):
             bi = self.num_encoder_layers + i
             if skips:
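The invariant the comments stress is worth seeing in isolation: the recurrence loop reuses encoder layers but never pushes to `skips`, so the decoder still pops exactly one skip per encoder layer. A toy sketch of the pattern, with `nn.Identity` standing in for the real blocks and weight banks:

```python
import torch
import torch.nn as nn

blocks = nn.ModuleList(nn.Identity() for _ in range(5))  # toy "encoder" stack
recur_layers = [2, 3, 4]                                 # layers to re-run
x, skips = torch.zeros(1, 8), []

for blk in blocks:        # encoder pass: one skip push per layer
    x = blk(x)
    skips.append(x)
for ri in recur_layers:   # recurrence pass: reuse layers, push nothing
    x = blocks[ri](x)

assert len(skips) == len(blocks)  # U-Net symmetry: 5 pushes in, 5 pops out
```

Note also that the re-run layers index into the same weight banks as the first pass, so recurrence adds compute but no parameters.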
@@ -1013,6 +1041,14 @@ def forward_hidden(self, input_ids: Tensor) -> Tensor:
             if v0 is None and raw_v is not None:
                 v0 = raw_v
             skips.append(x)
+        # Depth recurrence (mirrors forward()): no skip push, preserves U-Net symmetry.
+        if self.recur_enabled and self.recur_layers:
+            for ri in self.recur_layers:
+                ve = self._get_ve(ri, input_ids, ve_cache)
+                x, _ = self.blocks[ri](x, x0,
+                    self.qo_bank[ri], self.kv_bank[ri], self.kv_bank[n + ri],
+                    self.qo_bank[n + ri], self.mlp_up_bank[ri], self.mlp_down_bank[ri],
+                    v_embed=ve, v0=v0)
         for i in range(self.num_decoder_layers):
             bi = self.num_encoder_layers + i
             if skips:
@@ -1918,6 +1954,7 @@ def log0(msg: str, console: bool = True) -> None:
     ve_layers=args.ve_layers,
     gated_attention=args.gated_attention,
     value_residual=args.value_residual,
+    recur_layers_str=args.recur_layers_str,
 ).to(device).bfloat16()
 # Banks stay FP32 (like CastedLinear weights), cast to BF16 in forward
 base_model.qo_bank.data = base_model.qo_bank.data.float()
@@ -2074,6 +2111,21 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
 lawa_queue: deque[dict[str, Tensor]] = deque(maxlen=args.lawa_k)
 ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()}
 ema_decay = 0.997
+# Lookahead optimizer (Aweb signature; Zhang et al., NeurIPS 2019).
+# Slow-weight snapshot of ALL trainable params; updated every k inner steps.
+lookahead_slow: dict[str, Tensor] | None = None
+if args.lookahead_enabled:
+    lookahead_slow = {n: p.detach().float().clone()
+                      for n, p in base_model.named_parameters() if p.requires_grad}
+    if rank == 0:
+        log0(f"lookahead:enabled k={args.lookahead_k} alpha={args.lookahead_alpha} "
+             f"slow_params={len(lookahead_slow)}")
+elif rank == 0:
+    log0("lookahead:disabled")
+if base_model.recur_layers and rank == 0:
+    log0(f"recur:configured layers={base_model.recur_layers} start_step={args.recur_start_step}")
+elif rank == 0:
+    log0("recur:disabled")
 training_time_ms = 0.0
 stop_after_step: int | None = None
 torch.cuda.synchronize()
@@ -2153,6 +2205,22 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
         for name, t in base_model.state_dict().items():
             ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay)
     step += 1
+    # Lookahead update (Aweb signature): every k inner steps,
+    # slow := (1-a)*slow + a*fast; fast := slow.
+    if lookahead_slow is not None and step % args.lookahead_k == 0:
+        with torch.no_grad():
+            for n, p in base_model.named_parameters():
+                if n in lookahead_slow:
+                    slow = lookahead_slow[n]
+                    slow.mul_(1.0 - args.lookahead_alpha).add_(p.data.float(),
+                                                               alpha=args.lookahead_alpha)
+                    p.data.copy_(slow.to(p.dtype))
+    # Depth-recurrence curriculum: enable after the warmup steps.
+    if (base_model.recur_layers and not base_model.recur_enabled
+            and step >= args.recur_start_step):
+        base_model.recur_enabled = True
+        if rank == 0:
+            log0(f"recur:activated step={step} layers={base_model.recur_layers}")
     approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
     if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0:
         if swa_state is None:
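For reference, here is the Lookahead rule applied above, reduced to a self-contained toy: plain SGD on a quadratic, where only the update rule matches the diff (the objective and names are made up):

```python
import torch

w = torch.randn(10, requires_grad=True)  # fast weights
slow = w.detach().clone()                # slow weights
opt = torch.optim.SGD([w], lr=0.1)
k, alpha = 5, 0.5                        # the defaults above

for step in range(1, 101):
    loss = (w ** 2).sum()                # toy objective
    opt.zero_grad()
    loss.backward()
    opt.step()                           # inner (fast) update
    if step % k == 0:                    # outer Lookahead update
        slow.mul_(1 - alpha).add_(w.detach(), alpha=alpha)  # slow = (1-a)*slow + a*fast
        with torch.no_grad():
            w.copy_(slow)                # fast := slow
```

With k=5 and alpha=0.5 (the defaults above), the fast weights are pulled halfway back toward the slow trajectory every five optimizer steps.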
@@ -2264,7 +2332,10 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
     rope_dims=args.rope_dims, ln_scale=args.ln_scale, dtg=args.dtg_enabled,
     ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers,
     gated_attention=args.gated_attention, value_residual=args.value_residual,
+    recur_layers_str=args.recur_layers_str,
 ).to(device).bfloat16()
+# The eval model always uses recurrence (if configured): no curriculum at eval time.
+eval_model.recur_enabled = bool(eval_model.recur_layers)
 eval_model.qo_bank.data = eval_model.qo_bank.data.float()
 eval_model.kv_bank.data = eval_model.kv_bank.data.float()
 eval_model.mlp_up_bank.data = eval_model.mlp_up_bank.data.float()