diff --git a/examples/models/llama/source_transformation/quantized_kv_cache.py b/examples/models/llama/source_transformation/quantized_kv_cache.py
index 6d92a45e800..306c7380ecf 100644
--- a/examples/models/llama/source_transformation/quantized_kv_cache.py
+++ b/examples/models/llama/source_transformation/quantized_kv_cache.py
@@ -188,6 +188,26 @@ def update(self, input_pos, k_val, v_val):
             self.quantized_cache_dtype,
             self.cache_fp_type,
         )
+
+        if self.is_transposed:
+            if self.enable_dynamic_shape:
+                start_pos = input_pos[0].item()
+                torch._check_is_size(start_pos)
+                dim_to_slice = 2 if self.is_transposed else 1
+                torch._check(start_pos < self.k_cache.size(dim_to_slice))
+                seq_length = k_val.size(dim_to_slice)
+                narrowed_k = k_out.narrow(dim_to_slice, start_pos, seq_length)
+                narrowed_k.copy_(k_val)
+                narrowed_v = v_out.narrow(dim_to_slice, start_pos, seq_length)
+                narrowed_v.copy_(v_val)
+            else:
+                k_out[:, :, input_pos] = k_val
+                v_out[:, :, input_pos] = v_val
+        else:
+            start_pos = input_pos[0].item()
+            _ = torch.ops.llama.update_quantized_cache(k_val, k_out, start_pos)
+            _ = torch.ops.llama.update_quantized_cache(v_val, v_out, start_pos)
+
         return k_out, v_out
 
     @classmethod
diff --git a/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py b/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py
index 65c6678ab25..21952d8c211 100644
--- a/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py
+++ b/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py
@@ -66,12 +66,6 @@ def test_simple(self, is_dynamic_shape=False):
         torch.testing.assert_close(
             float_out,
             quantized_out,
-            # had to adjust rtol because switching to using custom_sdpa means we
-            # will use dequantized k and v instead of original k and v
-            # this leads to larger differences in the output.
-            # subsequent diff in the stack will address this issue.
-            rtol=1e-01,
-            atol=1e-03,
         )
 
         input_pos = torch.tensor([3], dtype=torch.int64)
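
For the transposed, dynamic-shape path the diff writes new tokens into the dequantized cache with `narrow` + `copy_` rather than advanced indexing, which keeps the update expressible for the exporter. Below is a minimal standalone sketch of that pattern; it is not the ExecuTorch implementation, and the names `update_cache`, `k_cache`, and `k_val` are illustrative assumptions with a cache layout assumed to be [batch, heads, max_seq_len, head_dim].

```python
import torch


def update_cache(cache: torch.Tensor, new_vals: torch.Tensor, start_pos: int) -> torch.Tensor:
    # Assumed layout [batch, heads, seq_len, head_dim]: the sequence
    # dimension to slice is dim 2 (matching the "transposed" branch above).
    dim_to_slice = 2
    seq_length = new_vals.size(dim_to_slice)
    # narrow() returns a view over [start_pos, start_pos + seq_length) along
    # the sequence dimension; copy_() writes the new tokens into it in place.
    cache.narrow(dim_to_slice, start_pos, seq_length).copy_(new_vals)
    return cache


if __name__ == "__main__":
    k_cache = torch.zeros(1, 8, 16, 64)        # batch, heads, max_seq_len, head_dim
    k_val = torch.randn(1, 8, 1, 64)           # one new token
    update_cache(k_cache, k_val, start_pos=3)  # writes the token at position 3
    assert torch.equal(k_cache[:, :, 3:4, :], k_val)
```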