diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index 90a8ec1e80d..ba09c8b4e78 100644
--- a/examples/models/llama2/export_llama_lib.py
+++ b/examples/models/llama2/export_llama_lib.py
@@ -474,9 +474,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
         transforms.append(replace_sdpa_with_custom_op)
 
     if args.quantize_kv_cache:
-        assert (
-            args.use_kv_cache and not args.use_sdpa_with_kv_cache
-        ), "quantize_kv_cache requires use_kv_cache=True and use_sdpa_with_kv_cache=False"
+        assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
         transforms.append(replace_kv_cache_with_quantized_kv_cache)
 
     if args.use_kv_cache:
diff --git a/examples/models/llama2/source_transformation/TARGETS b/examples/models/llama2/source_transformation/TARGETS
index 71687b8e1ff..0ddf8f19456 100644
--- a/examples/models/llama2/source_transformation/TARGETS
+++ b/examples/models/llama2/source_transformation/TARGETS
@@ -15,13 +15,45 @@ runtime.python_library(
     ],
 )
 
+runtime.python_library(
+    name = "sdpa",
+    srcs = [
+        "sdpa.py",
+    ],
+    _is_external_target = True,
+    base_module = "executorch.examples.models.llama2.source_transformation",
+    visibility = ["//executorch/..."],
+    deps = [
+        "//caffe2:torch",
+    ],
+)
+
 runtime.python_test(
     name = "quantized_kv_cache_test",
     srcs = [
         "test_quantized_kv_cache.py",
     ],
+    preload_deps = [
+        "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
+    ],
+    deps = [
+        ":quantized_kv_cache",
+        "//caffe2:torch",
+        "//executorch/examples/models/llama2:llama_transformer",
+    ],
+)
+
+runtime.python_test(
+    name = "quantized_sdpa_with_kv_cache_test",
+    srcs = [
+        "test_sdpa_with_quantized_kv_cache.py",
+    ],
+    preload_deps = [
+        "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
+    ],
     deps = [
         ":quantized_kv_cache",
+        ":sdpa",
         "//caffe2:torch",
         "//executorch/examples/models/llama2:llama_transformer",
     ],
diff --git a/examples/models/llama2/source_transformation/quantized_kv_cache.py b/examples/models/llama2/source_transformation/quantized_kv_cache.py
index c46f4696252..edb5973ba88 100644
--- a/examples/models/llama2/source_transformation/quantized_kv_cache.py
+++ b/examples/models/llama2/source_transformation/quantized_kv_cache.py
@@ -47,6 +47,7 @@ def __init__(
             raise ValueError(
                 f"Only affine symmetric and asymmetric cache types are supported: got {cache_type}"
             )
+        # For now supporting int8 only
         self.quantized_cache_dtype = torch.int8
         self.cache_fp_type = torch.float32
 
@@ -104,51 +105,78 @@ def update(self, input_pos, k_val, v_val):
             torch.int8,
         )
 
-        if self.enable_dynamic_shape:
-            start_pos = input_pos[0].item()
-            torch._check_is_size(start_pos)
-            if self.is_transposed:
-                dim_to_slice = 2
+        if self.is_transposed:
+            # We cannot use the update_cache op at the moment
+            # if the cache is transposed.
+            # Also note that we should not need separate paths
+            # for dynamic shape vs. not.
+            # The only reason it is done this way is to accommodate
+            # the lowering pains of backends that work better
+            # with the index_put op.
+            if self.enable_dynamic_shape:
+                start_pos = input_pos[0].item()
+                torch._check_is_size(start_pos)
+                if self.is_transposed:
+                    dim_to_slice = 2
+                else:
+                    dim_to_slice = 1
+                torch._check(start_pos < self.k_cache.size(dim_to_slice))
+                seq_length = k_val.size(dim_to_slice)
+                narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
+                narrowed_k_scales = self.k_cache_scales.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_k_zp = self.k_cache_zero_points.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_k.copy_(quantized_k_val)
+                narrowed_k_scales.copy_(k_scales)
+                narrowed_k_zp.copy_(k_zero_points)
+                # pyre-ignore: Incompatible parameter type [6]
+                narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
+                narrowed_v_scales = self.v_cache_scales.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_v_zp = self.v_cache_zero_points.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_v.copy_(quantized_v_val)
+                narrowed_v_scales.copy_(v_scales)
+                narrowed_v_zp.copy_(v_zero_points)
             else:
-                dim_to_slice = 1
-            torch._check(start_pos < self.k_cache.size(dim_to_slice))
-            seq_length = k_val.size(dim_to_slice)
-            narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
-            narrowed_k_scales = self.k_cache_scales.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_k_zp = self.k_cache_zero_points.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_k.copy_(quantized_k_val)
-            narrowed_k_scales.copy_(k_scales)
-            narrowed_k_zp.copy_(k_zero_points)
-            # pyre-ignore: Incompatible parameter type [6]
-            narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
-            narrowed_v_scales = self.v_cache_scales.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_v_zp = self.v_cache_zero_points.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_v.copy_(quantized_v_val)
-            narrowed_v_scales.copy_(v_scales)
-            narrowed_v_zp.copy_(v_zero_points)
-        else:
-            if self.is_transposed:
                 self.k_cache[:, :, input_pos] = quantized_k_val
                 self.k_cache_scales[:, :, input_pos] = k_scales
                 self.k_cache_zero_points[:, :, input_pos] = k_zero_points
                 self.v_cache[:, :, input_pos] = quantized_v_val
                 self.v_cache_scales[:, :, input_pos] = v_scales
                 self.v_cache_zero_points[:, :, input_pos] = v_zero_points
-            else:
-                self.k_cache[:, input_pos] = quantized_k_val
-                self.k_cache_scales[:, input_pos] = k_scales
-                self.k_cache_zero_points[:, input_pos] = k_zero_points
-                self.v_cache[:, input_pos] = quantized_v_val
-                self.v_cache_scales[:, input_pos] = v_scales
-                self.v_cache_zero_points[:, input_pos] = v_zero_points
+        else:
+            # Right now we are using custom ops on this path.
+            # In the future we can update the custom op to handle the
+            # transposed cache as well.
+            # Note that we may have to revert this change if other ET
+            # backends such as QNN want to use the quantized cache with
+            # dynamic shape instead of quantizing on their own.
+            # But until then, we opt for code simplicity.
+            start_pos = input_pos[0].item()
+            _ = torch.ops.llama.update_quantized_cache(
+                quantized_k_val, self.k_cache, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                k_scales, self.k_cache_scales, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                k_zero_points, self.k_cache_zero_points, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                quantized_v_val, self.v_cache, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                v_scales, self.v_cache_scales, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                v_zero_points, self.v_cache_zero_points, start_pos
+            )
 
         k_out = torch.ops.quantized_decomposed.dequantize_per_token(
             self.k_cache,
diff --git a/examples/models/llama2/source_transformation/sdpa.py b/examples/models/llama2/source_transformation/sdpa.py
index 0d2e4852e94..263a98a66b3 100644
--- a/examples/models/llama2/source_transformation/sdpa.py
+++ b/examples/models/llama2/source_transformation/sdpa.py
@@ -14,6 +14,9 @@
 import torch
 
 from executorch.examples.models.llama2.llama_transformer import KVCache, SDPA
+from executorch.examples.models.llama2.source_transformation.quantized_kv_cache import (
+    QuantizedKVCache,
+)
 
 
 class SDPACustom(torch.nn.Module):
@@ -36,12 +39,26 @@ def forward(
         seqlen,
         mask,
     ):
+        k_cache = self.kv_cache.k_cache
+        v_cache = self.kv_cache.v_cache
+        if isinstance(self.kv_cache, QuantizedKVCache):
+            # Updates the quantized cache, scales and zero points,
+            # and returns the dequantized kv cache.
+            # Not the most optimal; optimizations to follow.
+            k_cache, v_cache = self.kv_cache.update(input_pos, k, v)
+        # Note that this path will still mutate k_cache and v_cache in place.
+        # When we are not using the quantized kv cache, this will just mutate
+        # the original kv cache.
+        # When we are using the quantized kv cache, this will mutate the
+        # k_cache and v_cache returned from the cache update operation,
+        # which just dequantizes the cache and returns them.
+        # Future diffs will optimize this
         output = torch.ops.llama.sdpa_with_kv_cache(
             q,
             k,
             v,
-            self.kv_cache.k_cache,
-            self.kv_cache.v_cache,
+            k_cache,
+            v_cache,
             input_pos[-1].item(),
             seqlen,
             None,  # Attention mask
diff --git a/examples/models/llama2/source_transformation/test_sdpa_with_quantized_kv_cache.py b/examples/models/llama2/source_transformation/test_sdpa_with_quantized_kv_cache.py
new file mode 100644
index 00000000000..4d2cbfaf4d0
--- /dev/null
+++ b/examples/models/llama2/source_transformation/test_sdpa_with_quantized_kv_cache.py
@@ -0,0 +1,79 @@
+import unittest
+
+import torch
+
+from executorch.examples.models.llama2.llama_transformer import KVCache
+
+from executorch.examples.models.llama2.source_transformation.quantized_kv_cache import (
+    QuantizedCacheType,
+    QuantizedKVCache,
+)
+
+from executorch.examples.models.llama2.source_transformation.sdpa import SDPACustom
+
+
+class SDPAWithQuantizedKVCacheTest(unittest.TestCase):
+
+    def _init_cache(self):
+        self.kv_cache = KVCache(
+            self.max_batch_size,
+            self.max_seq_len,
+            self.n_kv_heads,
+            self.head_dim,
+            False,
+            self.enable_dynamic_shape,
+            dtype=self.dtype,
+        )
+        self.quantized_kv_cache = QuantizedKVCache.from_float(
+            self.kv_cache, QuantizedCacheType.AffineAsymmetric
+        )
+
+    def _init_kv(self):
+        kv_shape = (1, self.seq_len, self.n_kv_heads, self.head_dim)
+        q_shape = (1, self.seq_len, self.n_heads, self.head_dim)
+        q = torch.rand(q_shape, dtype=self.dtype)
+        k = torch.rand(kv_shape, dtype=self.dtype)
+        v = torch.rand(kv_shape, dtype=self.dtype)
+        return q, k, v
+
+    def setUp(self):
+        torch.manual_seed(42)
+        self.max_batch_size = 1
+        self.max_seq_len = 5
+        self.n_kv_heads = 4
+        self.n_heads = 8
+        self.head_dim = 17
+        self.dim = self.n_heads * self.head_dim
+        self.enable_dynamic_shape = False
+        self.dtype = torch.float32
+
+    def test_simple(self, is_dynamic_shape=False):
+        self.enable_dynamic_shape = is_dynamic_shape
+        input_pos = torch.tensor([0], dtype=torch.int64)
+        self.seq_len = 3
+        self._init_cache()
+        q, k, v = self._init_kv()
+        self.float_sdpa = SDPACustom(self.kv_cache, self.dim)
+        self.quantized_sdpa = SDPACustom(self.quantized_kv_cache, self.dim)
+        float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
+        quantized_out = self.quantized_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
+        self.assertTrue(
+            torch.allclose(
+                float_out,
+                quantized_out,
+            )
+        )
+
+        input_pos = torch.tensor([3], dtype=torch.int64)
+        self.seq_len = 1
+        q, k, v = self._init_kv()
+        float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
+        quantized_out = self.quantized_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
+        self.assertTrue(
+            torch.allclose(
+                float_out,
+                quantized_out,
+                rtol=1e-03,
+                atol=1e-03,
+            )
+        )
diff --git a/extension/llm/custom_ops/targets.bzl b/extension/llm/custom_ops/targets.bzl
index 503e4a0c7bd..a5bf280d76f 100644
--- a/extension/llm/custom_ops/targets.bzl
+++ b/extension/llm/custom_ops/targets.bzl
@@ -20,6 +20,7 @@ def define_common_targets():
            "op_sdpa.h",
        ],
        exported_deps = [
+            ":update_quantized_cache",
            "//executorch/runtime/kernel:kernel_includes",
            "//executorch/kernels/portable/cpu:scalar_utils",
            "//executorch/kernels/optimized:libblas{}".format(mkl_dep),
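
Usage sketch (illustrative only, not part of the patch): the snippet below mirrors the flow exercised by the new test_sdpa_with_quantized_kv_cache.py, converting a float KVCache with QuantizedKVCache.from_float and running it through SDPACustom. It assumes the LLM custom-ops AOT library (//executorch/extension/llm/custom_ops:custom_ops_aot_lib) has already been loaded so that torch.ops.llama.sdpa_with_kv_cache and torch.ops.llama.update_quantized_cache are registered; the shapes and the positional KVCache/SDPACustom arguments are copied from that test.

import torch

from executorch.examples.models.llama2.llama_transformer import KVCache
from executorch.examples.models.llama2.source_transformation.quantized_kv_cache import (
    QuantizedCacheType,
    QuantizedKVCache,
)
from executorch.examples.models.llama2.source_transformation.sdpa import SDPACustom

n_heads, n_kv_heads, head_dim, max_seq_len = 8, 4, 17, 5
dim = n_heads * head_dim

# Float cache first, then an int8 (affine asymmetric) copy of it.
float_cache = KVCache(
    1,  # max_batch_size
    max_seq_len,
    n_kv_heads,
    head_dim,
    False,  # same positional flag as in the test (cache not transposed)
    False,  # enable_dynamic_shape
    dtype=torch.float32,
)
quantized_cache = QuantizedKVCache.from_float(
    float_cache, QuantizedCacheType.AffineAsymmetric
)

# SDPACustom dequantizes on the fly when it is handed a QuantizedKVCache.
sdpa = SDPACustom(quantized_cache, dim)

seq_len = 3
input_pos = torch.tensor([0], dtype=torch.int64)
q = torch.rand(1, seq_len, n_heads, head_dim)
k = torch.rand(1, seq_len, n_kv_heads, head_dim)
v = torch.rand(1, seq_len, n_kv_heads, head_dim)
out = sdpa(input_pos, q, k, v, 1, seq_len, None)  # bsz=1, no attention mask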