
Commit 73b823b

Don't quantize the current token for attention

Pull Request resolved: #5715
ghstack-source-id: 255730816
@exported-using-ghexport
Differential Revision: [D63497872](https://our.internmc.facebook.com/intern/diff/D63497872/)
Co-authored-by: Kimish Patel <[email protected]>
1 parent 722de99 commit 73b823b


2 files changed: +20 -6 lines changed


examples/models/llama/source_transformation/quantized_kv_cache.py

Lines changed: 20 additions & 0 deletions
@@ -188,6 +188,26 @@ def update(self, input_pos, k_val, v_val):
             self.quantized_cache_dtype,
             self.cache_fp_type,
         )
+
+        if self.is_transposed:
+            if self.enable_dynamic_shape:
+                start_pos = input_pos[0].item()
+                torch._check_is_size(start_pos)
+                dim_to_slice = 2 if self.is_transposed else 1
+                torch._check(start_pos < self.k_cache.size(dim_to_slice))
+                seq_length = k_val.size(dim_to_slice)
+                narrowed_k = k_out.narrow(dim_to_slice, start_pos, seq_length)
+                narrowed_k.copy_(k_val)
+                narrowed_v = v_out.narrow(dim_to_slice, start_pos, seq_length)
+                narrowed_v.copy_(v_val)
+            else:
+                k_out[:, :, input_pos] = k_val
+                v_out[:, :, input_pos] = v_val
+        else:
+            start_pos = input_pos[0].item()
+            _ = torch.ops.llama.update_quantized_cache(k_val, k_out, start_pos)
+            _ = torch.ops.llama.update_quantized_cache(v_val, v_out, start_pos)
+
         return k_out, v_out

     @classmethod
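For context, here is a minimal standalone sketch of the dynamic-shape branch added above: the unquantized current-token values are written straight into the already-dequantized cache output with `narrow` + `copy_`, rather than being round-tripped through quantization. The shapes and names below (a `[batch, n_heads, max_seq_len, head_dim]` transposed layout, a single-token update) are illustrative assumptions, not taken from the ExecuTorch sources.

```python
import torch

# Assumed transposed cache layout: [batch, n_heads, max_seq_len, head_dim],
# so the sequence dimension to slice is 2 (matching dim_to_slice in the diff).
batch, n_heads, max_seq_len, head_dim = 1, 4, 16, 8
k_out = torch.zeros(batch, n_heads, max_seq_len, head_dim)

input_pos = torch.tensor([3], dtype=torch.int64)  # position of the current token
k_val = torch.randn(batch, n_heads, 1, head_dim)  # current-token keys, left unquantized

start_pos = input_pos[0].item()
seq_length = k_val.size(2)

# narrow() returns a view of k_out covering [start_pos, start_pos + seq_length)
# along dim 2; copy_() fills that view in place, leaving the rest of the cache untouched.
k_out.narrow(2, start_pos, seq_length).copy_(k_val)

assert torch.equal(k_out[:, :, start_pos : start_pos + seq_length], k_val)
```

The `torch._check_is_size` / `torch._check` calls in the diff are there to bound the data-dependent `start_pos`, so this slicing stays traceable when exporting with dynamic shapes.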

examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py

Lines changed: 0 additions & 6 deletions
@@ -66,12 +66,6 @@ def test_simple(self, is_dynamic_shape=False):
         torch.testing.assert_close(
             float_out,
             quantized_out,
-            # had to adjust rtol because switching to using custom_sdpa means we
-            # will use dequantized k and v instead of original k and v
-            # this leads to larger differences in the output.
-            # subsequent diff in the stack will address this issue.
-            rtol=1e-01,
-            atol=1e-03,
         )

         input_pos = torch.tensor([3], dtype=torch.int64)
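With the explicit `rtol`/`atol` overrides removed, `torch.testing.assert_close` falls back to its dtype-based defaults (for float32: `rtol=1.3e-6`, `atol=1e-5`), a much stricter comparison than the previous `rtol=1e-01`. A quick sketch of the difference, using made-up tensors rather than the test's actual outputs:

```python
import torch

expected = torch.ones(4)

# Within the float32 defaults (rtol=1.3e-6, atol=1e-5): passes with no overrides.
torch.testing.assert_close(expected + 1e-7, expected)

# A deviation of 1e-3 would fail under the defaults; it only passes when the
# tolerances are loosened explicitly, which this test no longer needs to do.
torch.testing.assert_close(expected + 1e-3, expected, rtol=1e-01, atol=1e-03)
```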
