@@ -652,6 +652,7 @@ def store_memories(

         # surprises

+        adaptive_lr = rearrange(adaptive_lr, '(b h n) c -> b h (n c)', b = batch, h = heads)
         unweighted_mem_model_loss = rearrange(unweighted_mem_model_loss, '(b h n) c -> b h (n c)', b = batch, h = heads)

         # maybe softclamp grad norm
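
Note on the hunk above: the '(b h n) c -> b h (n c)' pattern unflattens the (batch, head, chunk) leading dimension and merges the chunk axis with the chunk length, so adaptive_lr ends up with one value per token for each batch and head, mirroring how unweighted_mem_model_loss is already reshaped. A minimal sketch of that rearrange with hypothetical sizes (the real batch, heads, and chunk sizes come from the surrounding method):

import torch
from einops import rearrange

batch, heads, chunks, chunk_size = 2, 4, 3, 5                      # hypothetical sizes

adaptive_lr = torch.rand(batch * heads * chunks, chunk_size)       # layout '(b h n) c'
adaptive_lr = rearrange(adaptive_lr, '(b h n) c -> b h (n c)', b = batch, h = heads)

assert adaptive_lr.shape == (batch, heads, chunks * chunk_size)    # layout 'b h (n c)'
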
@@ -695,7 +696,7 @@ def store_memories(
         if not return_surprises:
             return output

-        return (*output, unweighted_mem_model_loss)
+        return (*output, (unweighted_mem_model_loss, adaptive_lr))

         # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates

@@ -755,7 +756,7 @@ def store_memories(
         if not return_surprises:
             return updates, next_store_state

-        return updates, next_store_state, unweighted_mem_model_loss
+        return updates, next_store_state, (unweighted_mem_model_loss, adaptive_lr)

     def retrieve_memories(
         self,
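
Both return sites change the same way: when return_surprises is set, the trailing element of the return is now a pair (unweighted_mem_model_loss, adaptive_lr) rather than the loss tensor alone, so callers unpack one extra level. A toy sketch of the new return shape only (placeholders throughout, not the real method):

import torch

def fake_store_memories(return_surprises = False):
    # placeholders standing in for the real updates / state / surprise tensors
    updates, next_store_state = dict(), dict()
    unweighted_mem_model_loss = torch.zeros(2, 4, 16)
    adaptive_lr = torch.zeros(2, 4, 16)

    if not return_surprises:
        return updates, next_store_state

    return updates, next_store_state, (unweighted_mem_model_loss, adaptive_lr)

updates, next_store_state, surprises = fake_store_memories(return_surprises = True)
unweighted_mem_model_loss, adaptive_lr = surprises   # one extra unpack versus before
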
@@ -939,7 +940,7 @@ def accum_updates(past_updates, future_updates):

         # whether to allow network to slowly adjust from initial weight throughout (residual path) to fully updating weights every batch

-        surprises = None
+        surprises = (None, None)
         gate = None

         if exists(self.transition_gate):
@@ -966,7 +967,7 @@ def accum_updates(past_updates, future_updates):

             updates = accum_updates(updates, next_updates)

-            surprises = safe_cat((surprises, chunk_surprises), dim = -1)
+            surprises = tuple(safe_cat(args, dim = -1) for args in zip(surprises, chunk_surprises))

             if is_last and not update_after_final_store:
                 continue
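
The last two hunks keep the chunked store loop consistent with the new pair: surprises starts as (None, None), and each chunk's (loss, adaptive_lr) pair is concatenated element-wise along the last dimension. A minimal sketch of that accumulation, assuming a safe_cat that skips Nones (the repo's actual helper is not shown in this diff):

import torch

def safe_cat(tensors, dim = -1):
    # assumed behavior: drop Nones, concatenate whatever remains
    tensors = [t for t in tensors if t is not None]
    if len(tensors) == 0:
        return None
    return torch.cat(tensors, dim = dim)

surprises = (None, None)

for _ in range(3):  # stand-in for the loop over stored chunks
    chunk_surprises = (torch.rand(2, 4, 8), torch.rand(2, 4, 8))  # (loss, adaptive_lr) for this chunk
    surprises = tuple(safe_cat(args, dim = -1) for args in zip(surprises, chunk_surprises))

unweighted_mem_model_loss, adaptive_lr = surprises
assert unweighted_mem_model_loss.shape == (2, 4, 24) and adaptive_lr.shape == (2, 4, 24)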