Commit f237471

Don't quantize the current token for attention
Pull Request resolved: #5715
ghstack-source-id: 245718144
@exported-using-ghexport
Differential Revision: [D63497872](https://our.internmc.facebook.com/intern/diff/D63497872/)
1 parent fb768ce

1 file changed: +21 −0 lines changed
examples/models/llama2/source_transformation/quantized_kv_cache.py

Lines changed: 21 additions & 0 deletions
```diff
@@ -188,6 +188,27 @@ def update(self, input_pos, k_val, v_val):
             self.quantized_cache_dtype,
             self.cache_fp_type,
         )
+
+        if self.is_transposed:
+            if self.enable_dynamic_shape:
+                start_pos = input_pos[0].item()
+                torch._check_is_size(start_pos)
+                dim_to_slice = 2 if self.is_transposed else 1
+                torch._check(start_pos < self.k_cache.size(dim_to_slice))
+                seq_length = k_val.size(dim_to_slice)
+                narrowed_k = k_out.narrow(dim_to_slice, start_pos, seq_length)
+                narrowed_k.copy_(k_val)
+                # pyre-ignore: Incompatible parameter type [6]
+                narrowed_v = v_out.narrow(dim_to_slice, start_pos, seq_length)
+                narrowed_v.copy_(v_val)
+            else:
+                k_out[:, :, input_pos] = k_val
+                v_out[:, :, input_pos] = v_val
+        else:
+            start_pos = input_pos[0].item()
+            _ = torch.ops.llama.update_quantized_cache(k_val, k_out, start_pos)
+            _ = torch.ops.llama.update_quantized_cache(v_val, v_out, start_pos)
+
         return k_out, v_out
 
     @classmethod
```
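
For readers unfamiliar with the tensor ops used here, the sketch below illustrates the `narrow` + `copy_` pattern from the dynamic-shape branch: the current token's key is written into the floating-point output cache as-is, rather than being round-tripped through quantize/dequantize. The variable names mirror the diff, but the concrete shapes and the transposed `[batch, heads, seq, head_dim]` layout are illustrative assumptions, not taken from the ExecuTorch sources.

```python
import torch

# Minimal, self-contained sketch (illustrative shapes, not the exact ExecuTorch setup).
batch, n_heads, max_seq_len, head_dim = 1, 4, 16, 8

k_out = torch.zeros(batch, n_heads, max_seq_len, head_dim)  # fp "output" cache
k_val = torch.randn(batch, n_heads, 1, head_dim)            # current token's key, full precision

start_pos = 5      # where the current token lands in the cache
dim_to_slice = 2   # sequence dimension in the assumed transposed layout
seq_length = k_val.size(dim_to_slice)

# narrow() returns a view over [start_pos, start_pos + seq_length) along the
# sequence dimension; copy_() fills that view in place with k_val, so the
# current token is stored without a quantize/dequantize round trip.
k_out.narrow(dim_to_slice, start_pos, seq_length).copy_(k_val)

assert torch.equal(k_out[:, :, start_pos : start_pos + seq_length], k_val)
```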
