@@ -103,6 +103,7 @@ class Hyperparameters:
     slot_lr = float(os.environ.get("SLOT_LR", 0.008))
     slot_lr_min = float(os.environ.get("SLOT_LR_MIN", 0.0008))
     slot_batch_seqs = int(os.environ.get("SLOT_BATCH_SEQS", 32))
+    slot_warmstart = float(os.environ.get("SLOT_WARMSTART", 0.0))  # 0=disabled; 0.85=warmstart from prev batch
     # Partial depth recurrence: repeat specified layers once more in the forward pass.
     # virtual_layers = [0,1,2,3,4,5,4,5,6,7,8,9,10] when recur_layers="4,5"
     recur_layers = os.environ.get("RECUR_LAYERS", "")  # e.g., "4,5"
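The `RECUR_LAYERS` comment above fully pins down the intended mapping. For reference, a minimal sketch that reproduces it (the helper name and the parsing are assumptions for illustration, not code from this patch):

```python
def build_virtual_layers(n_layers: int, recur_layers: str) -> list[int]:
    # The contiguous block named in RECUR_LAYERS is traversed a second
    # time immediately after its first pass.
    if not recur_layers:
        return list(range(n_layers))
    block = [int(s) for s in recur_layers.split(",")]
    order: list[int] = []
    for i in range(n_layers):
        order.append(i)
        if i == block[-1]:
            order.extend(block)  # repeat the block once more
    return order

assert build_virtual_layers(11, "4,5") == [0, 1, 2, 3, 4, 5, 4, 5, 6, 7, 8, 9, 10]
```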
@@ -890,6 +891,9 @@ def eval_val_slot(
     loss_sum = torch.zeros((), device=device, dtype=torch.float64)
     token_count = torch.zeros((), device=device, dtype=torch.float64)
     byte_sum = torch.zeros((), device=device, dtype=torch.float64)
+    warmstart_alpha = args.slot_warmstart
+    prev_delta = None
+    prev_bias = None
     base_model.eval()
     for bi in range(0, len(my_ws), args.slot_batch_seqs):
         bws = my_ws[bi:bi + args.slot_batch_seqs]
@@ -916,8 +920,12 @@ def eval_val_slot(
         valid_count = mask.sum()
         if valid_count == 0:
             continue
-        delta = torch.zeros(bsz, 1, hidden_f.size(-1), device=device, dtype=torch.float32, requires_grad=True)
-        logit_bias = torch.zeros(bsz, 1, proj_w.size(0), device=device, dtype=torch.float32, requires_grad=True)
+        if warmstart_alpha > 0 and prev_delta is not None and prev_delta.size(0) == bsz:
+            delta = (warmstart_alpha * prev_delta.detach().clone()).requires_grad_(True)
+            logit_bias = (warmstart_alpha * prev_bias.detach().clone()).requires_grad_(True)
+        else:
+            delta = torch.zeros(bsz, 1, hidden_f.size(-1), device=device, dtype=torch.float32, requires_grad=True)
+            logit_bias = torch.zeros(bsz, 1, proj_w.size(0), device=device, dtype=torch.float32, requires_grad=True)
         slot_opt = torch.optim.AdamW([delta, logit_bias], lr=args.slot_lr, weight_decay=1e-8, eps=1e-5)
         targets_flat = yb.reshape(-1)
         for step_i in range(args.slot_steps):
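This branch is the heart of the change: rather than always cold-starting the per-batch test-time optimization from zeros, it reuses the previous batch's solved `delta`/`logit_bias`, shrunk toward zero by `warmstart_alpha`. A standalone sketch of the initialization pattern (the helper is illustrative, not part of the patch):

```python
import torch

def init_slot_param(shape, alpha, prev, device):
    # Warm start: resume from the previous batch's solution, scaled by
    # alpha, whenever the leading (batch) dimensions match; otherwise
    # fall back to the original zero initialization.
    if alpha > 0 and prev is not None and prev.size(0) == shape[0]:
        return (alpha * prev.detach().clone()).requires_grad_(True)
    return torch.zeros(*shape, device=device, dtype=torch.float32, requires_grad=True)
```

Note the `prev_delta.size(0) == bsz` guard: the final batch of a rank's worklist may be smaller than `slot_batch_seqs`, in which case the code falls back to the zero init rather than slicing or padding the carried tensors.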
@@ -932,6 +940,9 @@ def eval_val_slot(
             slot_loss = (nll * mask).sum() / valid_count
             slot_loss.backward()
             slot_opt.step()
+        if warmstart_alpha > 0:
+            prev_delta = delta.detach()
+            prev_bias = logit_bias.detach()
         with torch.no_grad():
             h = hidden_f + delta.detach()
             lp = F.linear(h, proj_w) + logit_bias.detach()
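Taken together with the previous hunk, the loop now threads a solution across batches. A toy end-to-end reproduction of the pattern (all shapes and hyperparameters here are hypothetical, chosen only to make the snippet run):

```python
import torch
import torch.nn.functional as F

alpha = 0.85                                   # stand-in for SLOT_WARMSTART
hidden_f = torch.randn(4, 16, 64)              # (batch, seq, hidden), frozen
proj_w = torch.randn(100, 64)                  # (vocab, hidden), frozen
targets = torch.randint(0, 100, (4, 16))
prev_delta = prev_bias = None
for _batch in range(2):                        # stand-in for the batch loop
    if alpha > 0 and prev_delta is not None:
        delta = (alpha * prev_delta.clone()).requires_grad_(True)
        logit_bias = (alpha * prev_bias.clone()).requires_grad_(True)
    else:
        delta = torch.zeros(4, 1, 64, requires_grad=True)
        logit_bias = torch.zeros(4, 1, 100, requires_grad=True)
    opt = torch.optim.AdamW([delta, logit_bias], lr=8e-3, weight_decay=1e-8, eps=1e-5)
    for _step in range(10):
        opt.zero_grad(set_to_none=True)
        logits = F.linear(hidden_f + delta, proj_w) + logit_bias
        loss = F.cross_entropy(logits.reshape(-1, 100), targets.reshape(-1))
        loss.backward()
        opt.step()
    prev_delta, prev_bias = delta.detach(), logit_bias.detach()
```

One design note: `prev_delta = delta.detach()` stores the fully optimized tensor, so with `warmstart_alpha` close to 1 each batch inherits most of its neighbor's solution, while smaller values interpolate toward the old cold-start behavior.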
@@ -1496,7 +1507,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
 torch.cuda.synchronize()
 log0(
     f"final_slot val_loss:{slot_val_loss:.4f} val_bpb:{slot_val_bpb:.4f} "
-    f"steps:{args.slot_steps} lr:{args.slot_lr} eval_time:{1000.0 * (time.perf_counter() - t_slot):.0f}ms"
+    f"steps:{args.slot_steps} lr:{args.slot_lr} warmstart:{args.slot_warmstart} eval_time:{1000.0 * (time.perf_counter() - t_slot):.0f}ms"
 )
 log0(f"final_slot_exact val_loss:{slot_val_loss:.8f} val_bpb:{slot_val_bpb:.8f}")
 if distributed:
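Usage is unchanged apart from the new environment variable: leaving `SLOT_WARMSTART` at its default of `0.0` preserves the previous cold-start behavior, while e.g. `SLOT_WARMSTART=0.85` enables the carry-over described above.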