@@ -151,6 +151,7 @@ def __init__(
     ):
         super().__init__()
         self.max_seq_length = max_seq_length
+        self.transpose_cache = transpose_cache
         if transpose_cache:
             cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
         else:
@@ -173,28 +174,34 @@ def update(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache
         if self.enable_dynamic_shape:
-            start_pos = input_pos[-1].item()
+            start_pos = input_pos[0].item()
             torch._check_is_size(start_pos)
             torch._check(start_pos < self.max_seq_length)
-            seq_length = k_val.size(2)
+            dim_to_slice = 2 if self.transpose_cache else 1
+            seq_length = k_val.size(dim_to_slice)
             # Replace the entry in the cache for this token
             # The following lines are equivalent to:
             # cache_k[:bsz, start_pos : start_pos + seqlen] = xk
             # cache_v[:bsz, start_pos : start_pos + seqlen] = xv
+            # when dim_to_slice is 1
             # We use .narrow() here to make the compiler happy
             # pyre-ignore: Incompatible parameter type [6]
-            narrowed_k = self.k_cache.narrow(2, start_pos, seq_length)
+            narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
             # pyre-ignore: Incompatible parameter type [6]
-            narrowed_v = self.v_cache.narrow(2, start_pos, seq_length)
+            narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)

             narrowed_k.copy_(k_val)
             narrowed_v.copy_(v_val)
             return self.k_cache, self.v_cache
         else:
             k_out = self.k_cache
             v_out = self.v_cache
-            k_out[:, :, input_pos] = k_val
-            v_out[:, :, input_pos] = v_val
+            if self.transpose_cache:
+                k_out[:, :, input_pos] = k_val
+                v_out[:, :, input_pos] = v_val
+            else:
+                k_out[:, input_pos] = k_val
+                v_out[:, input_pos] = v_val

             return k_out, v_out

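For context, here is a minimal runnable sketch of what the patched cache does under both layouts. This is not the ExecuTorch source; `TinyKVCache` and its exact constructor are illustrative assumptions reconstructed from the diff above.

```python
from typing import Tuple

import torch


class TinyKVCache(torch.nn.Module):
    """Illustrative stand-in for the patched KVCache; names are assumptions."""

    def __init__(
        self,
        max_batch_size: int,
        max_seq_length: int,
        n_heads: int,
        head_dim: int,
        transpose_cache: bool,
        enable_dynamic_shape: bool,
        dtype=torch.float32,
    ):
        super().__init__()
        self.max_seq_length = max_seq_length
        self.transpose_cache = transpose_cache
        self.enable_dynamic_shape = enable_dynamic_shape
        if transpose_cache:
            # Sequence dimension is dim 2: [B, H, S, D]
            cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        else:
            # Sequence dimension is dim 1: [B, S, H, D]
            cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def update(
        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Pick the sequence dimension based on the cache layout.
        dim_to_slice = 2 if self.transpose_cache else 1
        if self.enable_dynamic_shape:
            # Export-friendly path: narrow() returns a view starting at the
            # (symbolic) start position; copy_() writes the entries in place.
            start_pos = input_pos[0].item()
            torch._check_is_size(start_pos)
            torch._check(start_pos < self.max_seq_length)
            seq_length = k_val.size(dim_to_slice)
            self.k_cache.narrow(dim_to_slice, start_pos, seq_length).copy_(k_val)
            self.v_cache.narrow(dim_to_slice, start_pos, seq_length).copy_(v_val)
        elif self.transpose_cache:
            self.k_cache[:, :, input_pos] = k_val
            self.v_cache[:, :, input_pos] = v_val
        else:
            self.k_cache[:, input_pos] = k_val
            self.v_cache[:, input_pos] = v_val
        return self.k_cache, self.v_cache


# Write one decoded token at position 5 into a non-transposed [B, S, H, D] cache.
cache = TinyKVCache(1, 16, 4, 8, transpose_cache=False, enable_dynamic_shape=False)
k = torch.randn(1, 1, 4, 8)
v = torch.randn(1, 1, 4, 8)
k_out, _ = cache.update(torch.tensor([5]), k, v)
assert torch.equal(k_out[:, 5], k[:, 0])
```

Note the two coupled changes in the dynamic-shape branch: the slice dimension now follows the layout (`dim_to_slice`), and the write starts at `input_pos[0]` rather than `input_pos[-1]`, so a contiguous run of `k_val.size(dim_to_slice)` positions beginning at the first decoded position is overwritten.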