Skip to content

Commit 99c643e

Browse files
committed
[ExecuTorch] Some updates to kv cache
Update kv cache impl to consider untransposed cache. Differential Revision: [D62301843](https://our.internmc.facebook.com/intern/diff/D62301843/) [ghstack-poisoned]
1 parent b2517d6 commit 99c643e

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

examples/models/llama2/llama_transformer.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ def __init__(
151151
):
152152
super().__init__()
153153
self.max_seq_length = max_seq_length
154+
self.is_tranposed = transpose_cache
154155
if transpose_cache:
155156
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
156157
else:
@@ -173,28 +174,34 @@ def update(
173174
) -> Tuple[torch.Tensor, torch.Tensor]:
174175
# input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache
175176
if self.enable_dynamic_shape:
176-
start_pos = input_pos[-1].item()
177+
start_pos = input_pos[0].item()
177178
torch._check_is_size(start_pos)
178179
torch._check(start_pos < self.max_seq_length)
179-
seq_length = k_val.size(2)
180+
dim_to_slice = 2 if self.transpose_cache else 1
181+
seq_length = k_val.size(dim_to_slice)
180182
# Replace the entry in the cache for this token
181183
# The following lines are equivalent to:
182184
# cache_k[:bsz, start_pos : start_pos + seqlen] = xk
183185
# cache_v[:bsz, start_pos : start_pos + seqlen] = xv
186+
# when dim_to_slice is 1
184187
# We use .narrow() here to make the compiler happy
185188
# pyre-ignore: Incompatible parameter type [6]
186-
narrowed_k = self.k_cache.narrow(2, start_pos, seq_length)
189+
narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
187190
# pyre-ignore: Incompatible parameter type [6]
188-
narrowed_v = self.v_cache.narrow(2, start_pos, seq_length)
191+
narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
189192

190193
narrowed_k.copy_(k_val)
191194
narrowed_v.copy_(v_val)
192195
return self.k_cache, self.v_cache
193196
else:
194197
k_out = self.k_cache
195198
v_out = self.v_cache
196-
k_out[:, :, input_pos] = k_val
197-
v_out[:, :, input_pos] = v_val
199+
if self.transpose_cache:
200+
k_out[:, :, input_pos] = k_val
201+
v_out[:, :, input_pos] = v_val
202+
else:
203+
k_out[:, input_pos] = k_val
204+
v_out[:, input_pos] = v_val
198205

199206
return k_out, v_out
200207

0 commit comments

Comments
 (0)