Commit fbfdd47

generalize all possible ways to graft the neural memory onto a transformer
1 parent 55b4d71 commit fbfdd47

3 files changed: 59 additions & 14 deletions

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.3.25"
+version = "0.4.0"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }

titans_pytorch/mac_transformer.py

Lines changed: 56 additions & 13 deletions
@@ -46,7 +46,7 @@ def create_mac_mask(_, __, q_idx, kv_idx):
 
 # einstein notation related
 
-from einops import repeat, rearrange, pack, unpack
+from einops import repeat, rearrange, pack, unpack, einsum
 from einops.layers.torch import Rearrange
 
 # b - batch
@@ -521,9 +521,7 @@ def __init__(
         self.sliding_window_attn = sliding_window_attn
         self.attn_window_size = segment_len + num_longterm_mem_tokens
 
-        # hyper conection
-
-        assert not (num_residual_streams <= 1 and neural_memory_qkv_receives_diff_views), 'allow neural memory queries, keys, values to be derived from different combinations of the residual streams can only work if hyper connections has greater than 1 residual stream'
+        # hyper connection
 
         init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim, add_stream_embed = True, disable = num_residual_streams == 1)
 
@@ -560,17 +558,28 @@ def __init__(
             )
 
             mem = None
+            mem_qkv_layer_selector = None
             mem_hyper_conn = None
 
             if layer in neural_memory_layers:
-                mem_hyper_conn = init_hyper_conn(add_branch_out_to_residual = not neural_mem_gate_attn_output, num_input_views = 3 if neural_memory_qkv_receives_diff_views else 1)
+                mem_hyper_conn = init_hyper_conn(add_branch_out_to_residual = not neural_mem_gate_attn_output)
+
+                if not is_first and neural_memory_qkv_receives_diff_views:
+                    num_layer_choices = (layer - 1) * 4 + 1 # for each layer, have memory input select from attn inp, attn out, ff inp, and ff out - plus one for the current point in the residual stream (memory input)
+
+                    mem_qkv_layer_selector = nn.Sequential(
+                        nn.RMSNorm(dim),
+                        nn.Linear(dim, 3 * num_layer_choices),
+                        Rearrange('... (views layers) -> views ... layers', views = 3),
+                        nn.Softmax(dim = -1)
+                    )
 
                 mem = NeuralMemory(
                     dim = dim,
                     chunk_size = self.neural_memory_segment_len,
                     batch_size = neural_memory_batch_size,
                     model = deepcopy(neural_memory_model),
-                    qkv_receives_diff_views = neural_memory_qkv_receives_diff_views,
+                    qkv_receives_diff_views = True,
                     accept_weight_residual = neural_mem_weight_residual and not is_first_neural_mem,
                     **neural_memory_kwargs
                 )
@@ -581,9 +590,12 @@ def __init__(
 
             self.layers.append(ModuleList([
                 mem_hyper_conn,
+                init_hyper_conn(),
+                init_hyper_conn(),
+                mem_qkv_layer_selector,
                 mem,
-                init_hyper_conn(branch = attn),
-                init_hyper_conn(branch = ff)
+                attn,
+                ff,
             ]))
 
         self.norm = nn.RMSNorm(dim)
@@ -763,6 +775,10 @@ def forward(
 
         mem_weight_residual = None
 
+        # layers for the neural mem to select the qkv inputs from
+
+        mem_input_layers = []
+
         # when inferencing, only do one token at a time
 
         if is_inferencing:
@@ -773,7 +789,7 @@
 
         x = self.expand_streams(x)
 
-        for mem_hyper_conn, mem, attn, ff in self.layers:
+        for mem_hyper_conn, attn_hyper_conn, ff_hyper_conn, mem_qkv_layer_selector, mem, attn, ff in self.layers:
 
             retrieved = None
             attn_out_gates = None
@@ -785,8 +801,19 @@
 
                 mem_input, add_residual = mem_hyper_conn(x)
 
+                if not exists(mem_qkv_layer_selector):
+                    qkv_mem_input = stack((mem_input, mem_input, mem_input))
+                else:
+                    layers_to_choose_from = stack((mem_input, *mem_input_layers))
+
+                    # let the current `mem_input` select the 3 layers for qkv
+
+                    selected = mem_qkv_layer_selector(mem_input)
+
+                    qkv_mem_input = einsum(layers_to_choose_from, selected, 'l b n d, v b n l -> v b n d')
+
                 retrieved, next_neural_mem_cache = mem.forward(
-                    mem_input,
+                    qkv_mem_input,
                     state = next(neural_mem_caches, None),
                     prev_weights = mem_weight_residual
                 )
@@ -801,25 +828,41 @@
 
             # attention
 
-            x, (values, next_kv_cache) = attn(
-                x,
+            attn_in, add_residual = attn_hyper_conn(x)
+
+            mem_input_layers.append(attn_in)
+
+            attn_out, (values, next_kv_cache) = attn(
+                attn_in,
                 value_residual = value_residual,
                 disable_flex_attn = disable_flex_attn,
                 flex_attn_fn = flex_attn_fn,
                 output_gating = attn_out_gates,
                 cache = next(kv_caches, None)
             )
 
+            mem_input_layers.append(attn_out)
+
             value_residual = default(value_residual, values)
 
+            x = add_residual(attn_out)
+
             # caches
 
             next_kv_caches.append(next_kv_cache)
             next_neural_mem_caches.append(next_neural_mem_cache)
 
             # feedforward
 
-            x = ff(x)
+            ff_in, add_ff_residual = ff_hyper_conn(x)
+
+            mem_input_layers.append(ff_in)
+
+            ff_out = ff(ff_in)
+
+            mem_input_layers.append(ff_out)
+
+            x = add_ff_residual(ff_out)
 
             # taking care of cache first
             # for early return when processing long term mem tokens during inference
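
To make the new wiring concrete, here is a minimal, self-contained sketch (not the repository's actual code path) of what the `mem_qkv_layer_selector` added above does at one layer: the current memory input produces softmax weights over every candidate hidden state gathered so far, and a separate weighted mixture is formed for each of the three views (queries, keys, values). The batch size, sequence length, model dimension and `num_layer_choices` below are illustrative assumptions.

import torch
from torch import nn, stack
from einops import einsum
from einops.layers.torch import Rearrange

dim = 64
batch, seq = 2, 16
num_layer_choices = 5   # e.g. (layer - 1) * 4 + 1 with layer = 2

# mirrors the nn.Sequential in the diff: norm -> project to 3 * choices -> split views -> softmax over choices
# (nn.RMSNorm needs a recent torch version, 2.4 or later)
selector = nn.Sequential(
    nn.RMSNorm(dim),
    nn.Linear(dim, 3 * num_layer_choices),
    Rearrange('... (views layers) -> views ... layers', views = 3),
    nn.Softmax(dim = -1)
)

# the current point in the residual stream, plus the hidden states that would sit in `mem_input_layers`
mem_input = torch.randn(batch, seq, dim)
mem_input_layers = [torch.randn(batch, seq, dim) for _ in range(num_layer_choices - 1)]

layers_to_choose_from = stack((mem_input, *mem_input_layers))   # (layers, batch, seq, dim)
selected = selector(mem_input)                                  # (views, batch, seq, layers)

# one weighted mixture of the candidate layers per view, the same einsum used in the diff
qkv_mem_input = einsum(layers_to_choose_from, selected, 'l b n d, v b n l -> v b n d')
print(qkv_mem_input.shape)   # torch.Size([3, 2, 16, 64])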

train_mac.py

Lines changed: 2 additions & 0 deletions
@@ -48,6 +48,7 @@
 STORE_ATTN_POOL_CHUNKS = True # whether to use attention pooling for chunk derived momentum, per-layer lr mod, decay
 MEMORY_MODEL_PER_LAYER_LEARNED_LR = True
 NEURAL_MEM_WEIGHT_RESIDUAL = True # learning to accept contributions from the weights of the previous neural mem layer brings about significant improvements. this was improvised and not in the paper, but inspired by the value residual learning free lunch paper
+NEURAL_MEM_QKV_RECEIVES_DIFF_VIEW = True # will allow the neural memory to select what layers from which to derive queries / keys / values, effectively allowing it to graft itself to the transformer in any way to be beneficial. this is to address an issue from a phd student who noted that the mem network is learning nothing more than wk @ wv. this also generalizes all possible ways to connect the neural memory to a transformer, a sort of NAS
 
 # experiment related

@@ -107,6 +108,7 @@ def decode_tokens(tokens):
     neural_memory_batch_size = NEURAL_MEM_BATCH_SIZE,
     neural_mem_gate_attn_output = NEURAL_MEM_GATE_ATTN_OUTPUT,
     neural_mem_weight_residual = NEURAL_MEM_WEIGHT_RESIDUAL,
+    neural_memory_qkv_receives_diff_views = NEURAL_MEM_QKV_RECEIVES_DIFF_VIEW,
     use_flex_attn = USE_FLEX_ATTN,
     sliding_window_attn = SLIDING_WINDOWS,
     neural_memory_model = neural_memory_model,
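
As a rough illustration of what enabling this flag means per layer (assuming 1-indexed layers and that `is_first` in the diff above refers to the first layer, with a hypothetical depth of 4): every earlier layer contributes four candidate hidden states (attention input, attention output, feedforward input, feedforward output), and the current point in the residual stream adds one more, giving the `(layer - 1) * 4 + 1` choices the selector softmaxes over.

depth = 4

for layer in range(1, depth + 1):
    if layer == 1:
        # no earlier layers to choose from, so no selector is built for the first layer
        print(f'layer {layer}: memory derives q / k / v from its own input')
        continue

    num_layer_choices = (layer - 1) * 4 + 1
    print(f'layer {layer}: selector softmaxes over {num_layer_choices} candidate hidden states')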
