
[Executorch][llama] Update SDPA op to use quantized kv cache #5600


Closed · wants to merge 1 commit
4 changes: 1 addition & 3 deletions examples/models/llama2/export_llama_lib.py
@@ -474,9 +474,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
transforms.append(replace_sdpa_with_custom_op)

if args.quantize_kv_cache:
assert (
args.use_kv_cache and not args.use_sdpa_with_kv_cache
), "quantize_kv_cache requires use_kv_cache=True and use_sdpa_with_kv_cache=False"
assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
transforms.append(replace_kv_cache_with_quantized_kv_cache)

if args.use_kv_cache:
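Since the relaxed check only requires `use_kv_cache`, the quantized cache can now be combined with the custom SDPA transform. Below is a minimal sketch of applying both source transforms to an eager model; the import paths and `replace_*` helper names are assumed from the ExecuTorch llama2 example tree and are not part of this diff.

```python
# Sketch only: import paths and helper names are assumed, not defined here.
from executorch.examples.models.llama2.source_transformation.quantized_kv_cache import (
    replace_kv_cache_with_quantized_kv_cache,
)
from executorch.examples.models.llama2.source_transformation.sdpa import (
    replace_sdpa_with_custom_op,
)


def apply_kv_cache_transforms(model):
    # Mirrors the order in _prepare_for_llama_export: swap in the custom SDPA
    # op first, then replace the float KV cache with the quantized one.
    model = replace_sdpa_with_custom_op(model)
    model = replace_kv_cache_with_quantized_kv_cache(model)
    return model
```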
32 changes: 32 additions & 0 deletions examples/models/llama2/source_transformation/TARGETS
@@ -15,13 +15,45 @@ runtime.python_library(
],
)

runtime.python_library(
name = "sdpa",
srcs = [
"sdpa.py",
],
_is_external_target = True,
base_module = "executorch.examples.models.llama2.source_transformation",
visibility = ["//executorch/..."],
deps = [
"//caffe2:torch",
],
)

runtime.python_test(
name = "quantized_kv_cache_test",
srcs = [
"test_quantized_kv_cache.py",
],
preload_deps = [
"//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
],
deps = [
":quantized_kv_cache",
"//caffe2:torch",
"//executorch/examples/models/llama2:llama_transformer",
],
)

runtime.python_test(
name = "quantized_sdpa_with_kv_cache_test",
srcs = [
"test_sdpa_with_quantized_kv_cache.py",
],
preload_deps = [
"//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
],
deps = [
":quantized_kv_cache",
":sdpa",
"//caffe2:torch",
"//executorch/examples/models/llama2:llama_transformer",
],
104 changes: 66 additions & 38 deletions examples/models/llama2/source_transformation/quantized_kv_cache.py
@@ -47,6 +47,7 @@ def __init__(
raise ValueError(
f"Only affine symmetric and asymmetric cache types are supported: got {cache_type}"
)

# For now supporting int8 only
self.quantized_cache_dtype = torch.int8
self.cache_fp_type = torch.float32
@@ -104,51 +105,78 @@ def update(self, input_pos, k_val, v_val):
torch.int8,
)

if self.enable_dynamic_shape:
start_pos = input_pos[0].item()
torch._check_is_size(start_pos)
if self.is_transposed:
dim_to_slice = 2
if self.is_transposed:
# We cannot use update_cache op at the moment
# if the cache is transposed
            # Also note that we should not need separate paths
            # for dynamic shape vs. not.
            # The only reason it is done this way is to accommodate
            # lowering pains of backends that work better
            # with the index_put op.
if self.enable_dynamic_shape:
start_pos = input_pos[0].item()
torch._check_is_size(start_pos)
if self.is_transposed:
dim_to_slice = 2
else:
dim_to_slice = 1
torch._check(start_pos < self.k_cache.size(dim_to_slice))
seq_length = k_val.size(dim_to_slice)
narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
narrowed_k_scales = self.k_cache_scales.narrow(
dim_to_slice, start_pos, seq_length
)
narrowed_k_zp = self.k_cache_zero_points.narrow(
dim_to_slice, start_pos, seq_length
)
narrowed_k.copy_(quantized_k_val)
narrowed_k_scales.copy_(k_scales)
narrowed_k_zp.copy_(k_zero_points)
# pyre-ignore: Incompatible parameter type [6]
narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
narrowed_v_scales = self.v_cache_scales.narrow(
dim_to_slice, start_pos, seq_length
)
narrowed_v_zp = self.v_cache_zero_points.narrow(
dim_to_slice, start_pos, seq_length
)
narrowed_v.copy_(quantized_v_val)
narrowed_v_scales.copy_(v_scales)
narrowed_v_zp.copy_(v_zero_points)
else:
dim_to_slice = 1
torch._check(start_pos < self.k_cache.size(dim_to_slice))
seq_length = k_val.size(dim_to_slice)
narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
narrowed_k_scales = self.k_cache_scales.narrow(
dim_to_slice, start_pos, seq_length
)
narrowed_k_zp = self.k_cache_zero_points.narrow(
dim_to_slice, start_pos, seq_length
)
narrowed_k.copy_(quantized_k_val)
narrowed_k_scales.copy_(k_scales)
narrowed_k_zp.copy_(k_zero_points)
# pyre-ignore: Incompatible parameter type [6]
narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
narrowed_v_scales = self.v_cache_scales.narrow(
dim_to_slice, start_pos, seq_length
)
narrowed_v_zp = self.v_cache_zero_points.narrow(
dim_to_slice, start_pos, seq_length
)
narrowed_v.copy_(quantized_v_val)
narrowed_v_scales.copy_(v_scales)
narrowed_v_zp.copy_(v_zero_points)
else:
if self.is_transposed:
self.k_cache[:, :, input_pos] = quantized_k_val
self.k_cache_scales[:, :, input_pos] = k_scales
self.k_cache_zero_points[:, :, input_pos] = k_zero_points
self.v_cache[:, :, input_pos] = quantized_v_val
self.v_cache_scales[:, :, input_pos] = v_scales
self.v_cache_zero_points[:, :, input_pos] = v_zero_points
else:
self.k_cache[:, input_pos] = quantized_k_val
self.k_cache_scales[:, input_pos] = k_scales
self.k_cache_zero_points[:, input_pos] = k_zero_points
self.v_cache[:, input_pos] = quantized_v_val
self.v_cache_scales[:, input_pos] = v_scales
self.v_cache_zero_points[:, input_pos] = v_zero_points
else:
            # Right now we use custom ops on this path.
            # In the future we can update the custom op to handle a
            # transposed cache as well.
            # Note that we may have to revert this change if other ET
            # backends such as QNN want to use the quantized cache with
            # dynamic shape, instead of quantizing on their own.
            # But until then we opt for code simplicity.
start_pos = input_pos[0].item()
_ = torch.ops.llama.update_quantized_cache(
quantized_k_val, self.k_cache, start_pos
)
_ = torch.ops.llama.update_quantized_cache(
k_scales, self.k_cache_scales, start_pos
)
_ = torch.ops.llama.update_quantized_cache(
k_zero_points, self.k_cache_zero_points, start_pos
)
_ = torch.ops.llama.update_quantized_cache(
quantized_v_val, self.v_cache, start_pos
)
_ = torch.ops.llama.update_quantized_cache(
v_scales, self.v_cache_scales, start_pos
)
_ = torch.ops.llama.update_quantized_cache(
v_zero_points, self.v_cache_zero_points, start_pos
)

k_out = torch.ops.quantized_decomposed.dequantize_per_token(
self.k_cache,
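For readers unfamiliar with the scheme used here: each token's k/v slice gets its own scale and zero point, and the cache stores int8 values alongside those parameters. The following is a self-contained, plain-PyTorch sketch of affine asymmetric per-token int8 quantization and dequantization; the helper names are hypothetical, and the module itself calls the `torch.ops.quantized_decomposed` kernels rather than this code.

```python
import torch


def quantize_per_token_asymmetric(x: torch.Tensor, eps: float = 1e-9):
    # One (scale, zero_point) pair per token, i.e. per slice along the last dim.
    qmin, qmax = -128, 127
    min_val = x.amin(dim=-1, keepdim=True)
    max_val = x.amax(dim=-1, keepdim=True)
    scale = (max_val - min_val).clamp(min=eps) / (qmax - qmin)
    zero_point = qmin - torch.round(min_val / scale)
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax).to(torch.int8)
    return q, scale, zero_point


def dequantize_per_token(q, scale, zero_point):
    return (q.to(torch.float32) - zero_point) * scale


# Round-trip: the error stays within roughly one quantization step per token.
k_val = torch.randn(1, 3, 4, 17)
q, s, zp = quantize_per_token_asymmetric(k_val)
k_hat = dequantize_per_token(q, s, zp)
print(f"max abs error: {(k_val - k_hat).abs().max():.5f} (max scale: {s.max():.5f})")
```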
21 changes: 19 additions & 2 deletions examples/models/llama2/source_transformation/sdpa.py
@@ -14,6 +14,9 @@
import torch

from executorch.examples.models.llama2.llama_transformer import KVCache, SDPA
from executorch.examples.models.llama2.source_transformation.quantized_kv_cache import (
QuantizedKVCache,
)


class SDPACustom(torch.nn.Module):
@@ -36,12 +39,26 @@ def forward(
seqlen,
mask,
):
k_cache = self.kv_cache.k_cache
v_cache = self.kv_cache.v_cache
if isinstance(self.kv_cache, QuantizedKVCache):
            # Update the quantized cache, scales and zero points;
            # returns the dequantized k and v caches.
            # Not the most optimal; optimizations to follow.
k_cache, v_cache = self.kv_cache.update(input_pos, k, v)
        # Note that this path will still in-place mutate k_cache and v_cache.
        # When we are not using the quantized kv cache, this just mutates
        # the original kv cache.
        # When we are using the quantized kv cache, this mutates the
        # k_cache and v_cache returned from the cache update operation,
        # which just dequantizes the cache and returns the result.
        # Future diffs will optimize this.
output = torch.ops.llama.sdpa_with_kv_cache(
q,
k,
v,
self.kv_cache.k_cache,
self.kv_cache.v_cache,
k_cache,
v_cache,
input_pos[-1].item(),
seqlen,
None, # Attention mask
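For intuition, here is a rough reference of the attention math the fused op performs once the caches hold the new entries; shapes follow the [batch, seq, heads, head_dim] layout used above. The function name is hypothetical and the exact masking and cache-update behavior of `sdpa_with_kv_cache` are assumptions, so treat this as a sketch rather than a drop-in replacement.

```python
import torch
import torch.nn.functional as F


def reference_sdpa_over_cache(q, k_cache, v_cache, start_pos, seqlen):
    # Attend over cache positions [0, start_pos + seqlen); query i sits at
    # absolute position start_pos + i.
    k = k_cache[:, : start_pos + seqlen]
    v = v_cache[:, : start_pos + seqlen]
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # -> [b, heads, seq, dim]
    # Grouped-query attention: repeat kv heads to match the query heads.
    n_rep = q.size(1) // k.size(1)
    if n_rep > 1:
        k = k.repeat_interleave(n_rep, dim=1)
        v = v.repeat_interleave(n_rep, dim=1)
    q_pos = torch.arange(start_pos, start_pos + seqlen).unsqueeze(-1)
    kv_pos = torch.arange(start_pos + seqlen).unsqueeze(0)
    causal = kv_pos <= q_pos  # True means "may attend"
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=causal)
    return out.transpose(1, 2).contiguous()
```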
79 changes: 79 additions & 0 deletions examples/models/llama2/source_transformation/test_sdpa_with_quantized_kv_cache.py
@@ -0,0 +1,79 @@
import unittest

import torch

from executorch.examples.models.llama2.llama_transformer import KVCache

from executorch.examples.models.llama2.source_transformation.quantized_kv_cache import (
QuantizedCacheType,
QuantizedKVCache,
)

from executorch.examples.models.llama2.source_transformation.sdpa import SDPACustom


class SDPAWithQuantizedKVCacheTest(unittest.TestCase):

def _init_cache(self):
self.kv_cache = KVCache(
self.max_batch_size,
self.max_seq_len,
self.n_kv_heads,
self.head_dim,
False,
self.enable_dynamic_shape,
dtype=self.dtype,
)
self.quantized_kv_cache = QuantizedKVCache.from_float(
self.kv_cache, QuantizedCacheType.AffineAsymmetric
)

def _init_kv(self):
kv_shape = (1, self.seq_len, self.n_kv_heads, self.head_dim)
q_shape = (1, self.seq_len, self.n_heads, self.head_dim)
q = torch.rand(q_shape, dtype=self.dtype)
k = torch.rand(kv_shape, dtype=self.dtype)
v = torch.rand(kv_shape, dtype=self.dtype)
return q, k, v

def setUp(self):
torch.manual_seed(42)
self.max_batch_size = 1
self.max_seq_len = 5
self.n_kv_heads = 4
self.n_heads = 8
self.head_dim = 17
self.dim = self.n_heads * self.head_dim
self.enable_dynamic_shape = False
self.dtype = torch.float32

def test_simple(self, is_dynamic_shape=False):
self.enable_dynamic_shape = is_dynamic_shape
input_pos = torch.tensor([0], dtype=torch.int64)
self.seq_len = 3
self._init_cache()
q, k, v = self._init_kv()
self.float_sdpa = SDPACustom(self.kv_cache, self.dim)
self.quantized_sdpa = SDPACustom(self.quantized_kv_cache, self.dim)
float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
quantized_out = self.quantized_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
self.assertTrue(
torch.allclose(
float_out,
quantized_out,
)
)

input_pos = torch.tensor([3], dtype=torch.int64)
self.seq_len = 1
q, k, v = self._init_kv()
float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
quantized_out = self.quantized_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
self.assertTrue(
torch.allclose(
float_out,
quantized_out,
rtol=1e-03,
atol=1e-03,
)
)
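Outside the unittest harness, wiring the quantized cache into `SDPACustom` follows the same pattern as `setUp`/`test_simple` above. A condensed sketch, assuming the `torch.ops.llama` custom ops are already registered (e.g. via the `custom_ops_aot_lib` preload the Buck test uses) and mirroring the test's shapes:

```python
import torch

from executorch.examples.models.llama2.llama_transformer import KVCache
from executorch.examples.models.llama2.source_transformation.quantized_kv_cache import (
    QuantizedCacheType,
    QuantizedKVCache,
)
from executorch.examples.models.llama2.source_transformation.sdpa import SDPACustom

# Shapes mirror the unit test: batch 1, max_seq_len 5, 4 kv heads, head_dim 17.
float_cache = KVCache(1, 5, 4, 17, False, False, dtype=torch.float32)
quantized_cache = QuantizedKVCache.from_float(
    float_cache, QuantizedCacheType.AffineAsymmetric
)
sdpa = SDPACustom(quantized_cache, 8 * 17)

q = torch.rand(1, 3, 8, 17)
k = torch.rand(1, 3, 4, 17)
v = torch.rand(1, 3, 4, 17)
out = sdpa(torch.tensor([0], dtype=torch.int64), q, k, v, 1, 3, None)
print(out.shape)  # e.g. torch.Size([1, 3, 136])
```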
1 change: 1 addition & 0 deletions extension/llm/custom_ops/targets.bzl
@@ -20,6 +20,7 @@ def define_common_targets():
"op_sdpa.h",
],
exported_deps = [
":update_quantized_cache",
"//executorch/runtime/kernel:kernel_includes",
"//executorch/kernels/portable/cpu:scalar_utils",
"//executorch/kernels/optimized:libblas{}".format(mkl_dep),