
Commit bd2c41e

[Executorch][llama] Update SDPA op to use quantized kv cache
With a quantized KV cache we can no longer rely on the SDPA op to update the original cache, so we insert a cache update op instead.

Differential Revision: [D62301841](https://our.internmc.facebook.com/intern/diff/D62301841/)

[ghstack-poisoned]
1 parent ede4406 commit bd2c41e
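
For orientation, here is a minimal sketch of the new order of operations in plain PyTorch, with a made-up helper name. The real flow uses the custom ops torch.ops.llama.update_quantized_cache and torch.ops.llama.sdpa_with_kv_cache shown in the diffs below, stores the cache as int8 with per-token scales and zero points, and handles masking and grouped KV heads, none of which is modeled here. The point is only the ordering: the cache is updated first, then attention runs over the (dequantized) cache.

    import math

    import torch


    def sdpa_with_external_cache_update(q, k, v, k_cache, v_cache, start_pos):
        # Hypothetical illustration: caches are plain fp32 [B, S, H, D] tensors
        # standing in for the dequantized cache; assumes n_heads == n_kv_heads.
        seq_len = k.size(1)
        # The cache update happens before, and outside of, the SDPA op.
        k_cache.narrow(1, start_pos, seq_len).copy_(k)
        v_cache.narrow(1, start_pos, seq_len).copy_(v)
        end = start_pos + seq_len
        # [B, S, H, D] -> [B, H, S, D] for attention.
        qh = q.transpose(1, 2)
        kh = k_cache[:, :end].transpose(1, 2)
        vh = v_cache[:, :end].transpose(1, 2)
        scores = qh @ kh.transpose(-1, -2) / math.sqrt(q.size(-1))
        return (torch.softmax(scores, dim=-1) @ vh).transpose(1, 2)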

6 files changed, +198 -43 lines

examples/models/llama2/export_llama_lib.py

Lines changed: 1 addition & 3 deletions

@@ -474,9 +474,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
         transforms.append(replace_sdpa_with_custom_op)
 
     if args.quantize_kv_cache:
-        assert (
-            args.use_kv_cache and not args.use_sdpa_with_kv_cache
-        ), "quantize_kv_cache requires use_kv_cache=True and use_sdpa_with_kv_cache=False"
+        assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
         transforms.append(replace_kv_cache_with_quantized_kv_cache)
 
     if args.use_kv_cache:
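
With the relaxed assert, a quantized KV cache may now be requested together with the custom SDPA op. A rough, self-contained sketch of the composition this enables, using stand-in transforms and args (the real code appends the actual transforms to a list that LLMEdgeManager applies later):

    from argparse import Namespace


    # Stand-ins for the real source transforms, used only to show the composition.
    def replace_sdpa_with_custom_op(model):
        return model


    def replace_kv_cache_with_quantized_kv_cache(model):
        return model


    args = Namespace(use_kv_cache=True, use_sdpa_with_kv_cache=True, quantize_kv_cache=True)
    model, transforms = object(), []
    if args.use_sdpa_with_kv_cache:
        transforms.append(replace_sdpa_with_custom_op)
    if args.quantize_kv_cache:
        # No longer forbids use_sdpa_with_kv_cache.
        assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
        transforms.append(replace_kv_cache_with_quantized_kv_cache)
    for transform in transforms:
        model = transform(model)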

examples/models/llama2/source_transformation/TARGETS

Lines changed: 32 additions & 0 deletions

@@ -15,13 +15,45 @@ runtime.python_library(
     ],
 )
 
+runtime.python_library(
+    name = "sdpa",
+    srcs = [
+        "sdpa.py",
+    ],
+    _is_external_target = True,
+    base_module = "executorch.examples.models.llama2.source_transformation",
+    visibility = ["//executorch/..."],
+    deps = [
+        "//caffe2:torch",
+    ],
+)
+
 runtime.python_test(
     name = "quantized_kv_cache_test",
     srcs = [
         "test_quantized_kv_cache.py",
     ],
+    preload_deps = [
+        "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
+    ],
+    deps = [
+        ":quantized_kv_cache",
+        "//caffe2:torch",
+        "//executorch/examples/models/llama2:llama_transformer",
+    ],
+)
+
+runtime.python_test(
+    name = "quantized_sdpa_with_kv_cache_test",
+    srcs = [
+        "test_sdpa_with_quantized_kv_cache.py",
+    ],
+    preload_deps = [
+        "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
+    ],
     deps = [
         ":quantized_kv_cache",
+        ":sdpa",
         "//caffe2:torch",
         "//executorch/examples/models/llama2:llama_transformer",
     ],

examples/models/llama2/source_transformation/quantized_kv_cache.py

Lines changed: 66 additions & 38 deletions

@@ -47,6 +47,7 @@ def __init__(
             raise ValueError(
                 f"Only affine symmetric and asymmetric cache types are supported: got {cache_type}"
             )
+
         # For now supporting int8 only
         self.quantized_cache_dtype = torch.int8
         self.cache_fp_type = torch.float32
@@ -104,51 +105,78 @@ def update(self, input_pos, k_val, v_val):
             torch.int8,
         )
 
-        if self.enable_dynamic_shape:
-            start_pos = input_pos[0].item()
-            torch._check_is_size(start_pos)
-            if self.is_transposed:
-                dim_to_slice = 2
+        if self.is_transposed:
+            # We cannot use the update_cache op at the moment
+            # if the cache is transposed.
+            # Also note that we should not need separate paths
+            # for dynamic shape vs. static shape.
+            # The only reason it is done this way is to accommodate
+            # lowering pains of backends that work better
+            # with the index_put op.
+            if self.enable_dynamic_shape:
+                start_pos = input_pos[0].item()
+                torch._check_is_size(start_pos)
+                if self.is_transposed:
+                    dim_to_slice = 2
+                else:
+                    dim_to_slice = 1
+                torch._check(start_pos < self.k_cache.size(dim_to_slice))
+                seq_length = k_val.size(dim_to_slice)
+                narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
+                narrowed_k_scales = self.k_cache_scales.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_k_zp = self.k_cache_zero_points.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_k.copy_(quantized_k_val)
+                narrowed_k_scales.copy_(k_scales)
+                narrowed_k_zp.copy_(k_zero_points)
+                # pyre-ignore: Incompatible parameter type [6]
+                narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
+                narrowed_v_scales = self.v_cache_scales.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_v_zp = self.v_cache_zero_points.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_v.copy_(quantized_v_val)
+                narrowed_v_scales.copy_(v_scales)
+                narrowed_v_zp.copy_(v_zero_points)
             else:
-                dim_to_slice = 1
-            torch._check(start_pos < self.k_cache.size(dim_to_slice))
-            seq_length = k_val.size(dim_to_slice)
-            narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
-            narrowed_k_scales = self.k_cache_scales.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_k_zp = self.k_cache_zero_points.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_k.copy_(quantized_k_val)
-            narrowed_k_scales.copy_(k_scales)
-            narrowed_k_zp.copy_(k_zero_points)
-            # pyre-ignore: Incompatible parameter type [6]
-            narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
-            narrowed_v_scales = self.v_cache_scales.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_v_zp = self.v_cache_zero_points.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_v.copy_(quantized_v_val)
-            narrowed_v_scales.copy_(v_scales)
-            narrowed_v_zp.copy_(v_zero_points)
-        else:
-            if self.is_transposed:
                 self.k_cache[:, :, input_pos] = quantized_k_val
                 self.k_cache_scales[:, :, input_pos] = k_scales
                 self.k_cache_zero_points[:, :, input_pos] = k_zero_points
                 self.v_cache[:, :, input_pos] = quantized_v_val
                 self.v_cache_scales[:, :, input_pos] = v_scales
                 self.v_cache_zero_points[:, :, input_pos] = v_zero_points
-            else:
-                self.k_cache[:, input_pos] = quantized_k_val
-                self.k_cache_scales[:, input_pos] = k_scales
-                self.k_cache_zero_points[:, input_pos] = k_zero_points
-                self.v_cache[:, input_pos] = quantized_v_val
-                self.v_cache_scales[:, input_pos] = v_scales
-                self.v_cache_zero_points[:, input_pos] = v_zero_points
+        else:
+            # Right now using custom ops on this path.
+            # In the future we can update the custom op to handle a
+            # transposed cache as well.
+            # Note that we may have to revert this change if other ET
+            # backends such as QNN want to use the quantized cache with
+            # dynamic shape, instead of quantizing on their own.
+            # But until then, opting for code simplicity.
+            start_pos = input_pos[0].item()
+            _ = torch.ops.llama.update_quantized_cache(
+                quantized_k_val, self.k_cache, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                k_scales, self.k_cache_scales, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                k_zero_points, self.k_cache_zero_points, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                quantized_v_val, self.v_cache, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                v_scales, self.v_cache_scales, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                v_zero_points, self.v_cache_zero_points, start_pos
+            )
 
         k_out = torch.ops.quantized_decomposed.dequantize_per_token(
             self.k_cache,
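
Since update_quantized_cache appears here only through its call sites, a reference sketch of its assumed semantics may help (this is an inference from the calls above, not the custom op's implementation): copy the new values into the preallocated cache along the sequence dimension starting at start_pos, mutating the cache in place; the op is invoked separately for the int8 values, the scales, and the zero points.

    import torch


    def update_quantized_cache_reference(value, cache, start_pos):
        # Assumed semantics: write `value` into `cache` along dim 1 (the sequence
        # dimension of the non-transposed [B, S, H, D] layout), starting at start_pos.
        seq_len = value.size(1)
        cache.narrow(1, start_pos, seq_len).copy_(value)
        return cache


    # Illustration only; shapes and dtypes are chosen for the example, not taken from the diff.
    k_cache = torch.zeros(1, 8, 4, 16, dtype=torch.int8)
    k_scales = torch.ones(1, 8, 4, 1)
    new_k = torch.randint(-128, 127, (1, 3, 4, 16), dtype=torch.int8)
    new_scales = torch.rand(1, 3, 4, 1)
    update_quantized_cache_reference(new_k, k_cache, start_pos=0)
    update_quantized_cache_reference(new_scales, k_scales, start_pos=0)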

examples/models/llama2/source_transformation/sdpa.py

Lines changed: 19 additions & 2 deletions

@@ -14,6 +14,9 @@
 import torch
 
 from executorch.examples.models.llama2.llama_transformer import KVCache, SDPA
+from executorch.examples.models.llama2.source_transformation.quantized_kv_cache import (
+    QuantizedKVCache,
+)
 
 
 class SDPACustom(torch.nn.Module):
@@ -36,12 +39,26 @@ def forward(
         seqlen,
         mask,
     ):
+        k_cache = self.kv_cache.k_cache
+        v_cache = self.kv_cache.v_cache
+        if isinstance(self.kv_cache, QuantizedKVCache):
+            # Updates the quantized cache, scales and zero points, and
+            # returns the dequantized kv cache.
+            # Not the most optimal; optimizations to follow next.
+            k_cache, v_cache = self.kv_cache.update(input_pos, k, v)
+        # Note that this path will still in-place mutate k_cache, v_cache.
+        # When we are not using the quantized kv cache, this will just
+        # mutate the original kv cache.
+        # When we are using the quantized kv cache, this will mutate the
+        # k_cache, v_cache returned from the cache update operation,
+        # which just dequantizes the cache and returns that.
+        # Future diffs will optimize this.
         output = torch.ops.llama.sdpa_with_kv_cache(
             q,
             k,
             v,
-            self.kv_cache.k_cache,
-            self.kv_cache.v_cache,
+            k_cache,
+            v_cache,
             input_pos[-1].item(),
             seqlen,
             None,  # Attention mask
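
Because the SDPA op now consumes a dequantized copy of an int8 cache, the quantized path can only match the float path approximately; the test below allows rtol/atol of 1e-3, which is consistent with the round-trip error of per-token affine asymmetric int8 quantization. A toy illustration in plain torch (the real code uses the quantized_decomposed ops shown above):

    import torch


    def int8_asym_roundtrip(x):
        # Toy per-token affine asymmetric int8 quantize/dequantize over the last dim.
        x2d = x.reshape(-1, x.size(-1))
        min_val = x2d.min(dim=-1, keepdim=True).values
        max_val = x2d.max(dim=-1, keepdim=True).values
        scales = (max_val - min_val).clamp(min=1e-6) / 255.0
        zero_points = -128 - torch.round(min_val / scales)
        q = torch.clamp(torch.round(x2d / scales) + zero_points, -128, 127).to(torch.int8)
        dq = (q.to(torch.float32) - zero_points) * scales
        return dq.reshape(x.shape)


    x = torch.rand(1, 3, 4, 17)
    print((int8_asym_roundtrip(x) - x).abs().max())  # on the order of 1e-3 for rand inputs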
examples/models/llama2/source_transformation/test_sdpa_with_quantized_kv_cache.py

Lines changed: 79 additions & 0 deletions

@@ -0,0 +1,79 @@
+import unittest
+
+import torch
+
+from executorch.examples.models.llama2.llama_transformer import KVCache
+
+from executorch.examples.models.llama2.source_transformation.quantized_kv_cache import (
+    QuantizedCacheType,
+    QuantizedKVCache,
+)
+
+from executorch.examples.models.llama2.source_transformation.sdpa import SDPACustom
+
+
+class SDPAWithQuantizedKVCacheTest(unittest.TestCase):
+
+    def _init_cache(self):
+        self.kv_cache = KVCache(
+            self.max_batch_size,
+            self.max_seq_len,
+            self.n_kv_heads,
+            self.head_dim,
+            False,
+            self.enable_dynamic_shape,
+            dtype=self.dtype,
+        )
+        self.quantized_kv_cache = QuantizedKVCache.from_float(
+            self.kv_cache, QuantizedCacheType.AffineAsymmetric
+        )
+
+    def _init_kv(self):
+        kv_shape = (1, self.seq_len, self.n_kv_heads, self.head_dim)
+        q_shape = (1, self.seq_len, self.n_heads, self.head_dim)
+        q = torch.rand(q_shape, dtype=self.dtype)
+        k = torch.rand(kv_shape, dtype=self.dtype)
+        v = torch.rand(kv_shape, dtype=self.dtype)
+        return q, k, v
+
+    def setUp(self):
+        torch.manual_seed(42)
+        self.max_batch_size = 1
+        self.max_seq_len = 5
+        self.n_kv_heads = 4
+        self.n_heads = 8
+        self.head_dim = 17
+        self.dim = self.n_heads * self.head_dim
+        self.enable_dynamic_shape = False
+        self.dtype = torch.float32
+
+    def test_simple(self, is_dynamic_shape=False):
+        self.enable_dynamic_shape = is_dynamic_shape
+        input_pos = torch.tensor([0], dtype=torch.int64)
+        self.seq_len = 3
+        self._init_cache()
+        q, k, v = self._init_kv()
+        self.float_sdpa = SDPACustom(self.kv_cache, self.dim)
+        self.quantized_sdpa = SDPACustom(self.quantized_kv_cache, self.dim)
+        float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
+        quantized_out = self.quantized_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
+        self.assertTrue(
+            torch.allclose(
+                float_out,
+                quantized_out,
+            )
+        )
+
+        input_pos = torch.tensor([3], dtype=torch.int64)
+        self.seq_len = 1
+        q, k, v = self._init_kv()
+        float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
+        quantized_out = self.quantized_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
+        self.assertTrue(
+            torch.allclose(
+                float_out,
+                quantized_out,
+                rtol=1e-03,
+                atol=1e-03,
+            )
+        )
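
The TARGETS change above preloads //executorch/extension/llm/custom_ops:custom_ops_aot_lib for this test; that shared library is what registers torch.ops.llama.sdpa_with_kv_cache and torch.ops.llama.update_quantized_cache before the test imports run. Assuming the usual Buck workflow in the executorch tree, the test would be run with something like buck2 test //executorch/examples/models/llama2/source_transformation:quantized_sdpa_with_kv_cache_test.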

extension/llm/custom_ops/targets.bzl

Lines changed: 1 addition & 0 deletions

@@ -20,6 +20,7 @@ def define_common_targets():
             "op_sdpa.h",
         ],
         exported_deps = [
+            ":update_quantized_cache",
            "//executorch/runtime/kernel:kernel_includes",
            "//executorch/kernels/portable/cpu:scalar_utils",
            "//executorch/kernels/optimized:libblas{}".format(mkl_dep),
