Skip to content

Commit 7ec702f

Browse files
committed
return the adaptive lr alongside the mem model losses, as the model could have chosen not to store anything
1 parent 9a10ee9 commit 7ec702f

File tree

3 files changed

+8
-7
lines changed

3 files changed

+8
-7
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "titans-pytorch"
3-
version = "0.4.1"
3+
version = "0.4.3"
44
description = "Titans"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_titans.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,9 @@ def test_return_surprises():
8585

8686
seq = torch.randn(4, 64, 384)
8787

88-
_, _, surprises = mem(seq, return_surprises = True)
88+
_, _, (surprises, adaptive_lr) = mem(seq, return_surprises = True)
8989

90-
assert surprises.shape == (4, 4, 64)
90+
assert all([t.shape == (4, 4, 64) for t in (surprises, adaptive_lr)])
9191

9292
@pytest.mark.parametrize('learned_momentum_combine', (False, True))
9393
@pytest.mark.parametrize('learned_combine_include_zeroth', (False, True))

titans_pytorch/neural_memory.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,7 @@ def store_memories(
652652

653653
# surprises
654654

655+
adaptive_lr = rearrange(adaptive_lr, '(b h n) c -> b h (n c)', b = batch, h = heads)
655656
unweighted_mem_model_loss = rearrange(unweighted_mem_model_loss, '(b h n) c -> b h (n c)', b = batch, h = heads)
656657

657658
# maybe softclamp grad norm
@@ -695,7 +696,7 @@ def store_memories(
695696
if not return_surprises:
696697
return output
697698

698-
return (*output, unweighted_mem_model_loss)
699+
return (*output, (unweighted_mem_model_loss, adaptive_lr))
699700

700701
# momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates
701702

@@ -755,7 +756,7 @@ def store_memories(
755756
if not return_surprises:
756757
return updates, next_store_state
757758

758-
return updates, next_store_state, unweighted_mem_model_loss
759+
return updates, next_store_state, (unweighted_mem_model_loss, adaptive_lr)
759760

760761
def retrieve_memories(
761762
self,
@@ -939,7 +940,7 @@ def accum_updates(past_updates, future_updates):
939940

940941
# whether to allow network to slowly adjust from initial weight throughout (residual path) to fully updating weights every batch
941942

942-
surprises = None
943+
surprises = (None, None)
943944
gate = None
944945

945946
if exists(self.transition_gate):
@@ -966,7 +967,7 @@ def accum_updates(past_updates, future_updates):
966967

967968
updates = accum_updates(updates, next_updates)
968969

969-
surprises = safe_cat((surprises, chunk_surprises), dim = -1)
970+
surprises = tuple(safe_cat(args, dim = -1) for args in zip(surprises, chunk_surprises))
970971

971972
if is_last and not update_after_final_store:
972973
continue

0 commit comments

Comments (0)