Commit dd678eb

[ExecuTorch] Add quantized kv cache to llama

This diff:
- adds a quantized KV cache implementation and applies the corresponding source transforms
- adds support for per-token quant/dequant in the quantized kernels

Differential Revision: [D62301844](https://our.internmc.facebook.com/intern/diff/D62301844/)

[ghstack-poisoned]
1 parent c01e703 commit dd678eb
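
The per-token quant/dequant the commit message refers to is the decomposed-op flow the new cache is built on: one (scale, zero_point) pair is chosen per token, values are quantized to int8 on write and dequantized on read. A minimal round-trip sketch using those ops (the tensor shape below is illustrative only, not taken from the diff):

# Minimal sketch of the per-token int8 quant/dequant round trip used by the new cache.
# The (batch, seq, heads, head_dim) shape is illustrative only.
import torch
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401

x = torch.randn(1, 4, 8, 64, dtype=torch.float32)
qmin, qmax = torch.iinfo(torch.int8).min, torch.iinfo(torch.int8).max

# One (scale, zero_point) pair per token, i.e. per slice over the last dimension.
scales, zero_points = (
    torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(
        x, torch.int8
    )
)
x_q = torch.ops.quantized_decomposed.quantize_per_token(
    x, scales, zero_points, qmin, qmax, torch.int8
)
x_dq = torch.ops.quantized_decomposed.dequantize_per_token(
    x_q, scales, zero_points, qmin, qmax, torch.int8, torch.float32
)
print(torch.max(torch.abs(x - x_dq)))  # small quantization error expected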

File tree: 15 files changed, +1106 −35 lines


examples/models/llama2/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -73,6 +73,7 @@ runtime.python_library(
         "source_transformation/apply_spin_quant_r1_r2.py",
         "source_transformation/prune_output.py",
         "source_transformation/quantize.py",
+        "source_transformation/quantized_kv_cache.py",
         "source_transformation/rms_norm.py",
         "source_transformation/rope.py",
         "source_transformation/sdpa.py",

examples/models/llama2/export_llama_lib.py

Lines changed: 16 additions & 1 deletion
@@ -53,7 +53,11 @@
     get_quant_embedding_transform,
     get_quant_weight_transform,
 )
+from .source_transformation.quantized_kv_cache import (
+    replace_kv_cache_with_quantized_kv_cache,
+)
 from .source_transformation.rms_norm import replace_rms_norm_with_native_rms_norm
+
 from .source_transformation.rope import materialze_broadcast_of_rope_freq_cis
 from .source_transformation.sdpa import (
     replace_causal_mask,

@@ -206,6 +210,12 @@ def build_args_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="Whether or not to export a model using kv cache",
     )
+    parser.add_argument(
+        "--quantize_kv_cache",
+        default=False,
+        action="store_true",
+        help="Whether or not to export a model using int8 per token quantized kv cache",
+    )
     parser.add_argument(
         "--num_sharding",
         type=int,

@@ -428,7 +438,6 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
 
     Returns a LLMEdgeManager prior to calling export_to_edge with quantizers
     """
-
     # load model from checkpoint and params.json
     checkpoint_path = canonical_path(args.checkpoint) if args.checkpoint else None
     checkpoint_dir = (

@@ -806,6 +815,12 @@ def _get_source_transforms( # noqa
     if args.use_sdpa_with_kv_cache:
         transforms.append(replace_sdpa_with_custom_op)
 
+    if args.quantize_kv_cache:
+        assert (
+            args.use_kv_cache and not args.use_sdpa_with_kv_cache
+        ), "quantize_kv_cache requires use_kv_cache=True and use_sdpa_with_kv_cache=False"
+        transforms.append(replace_kv_cache_with_quantized_kv_cache)
+
     if args.use_kv_cache:
         if args.qnn:
             # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
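
In effect, the new flag just appends replace_kv_cache_with_quantized_kv_cache to the source-transform list. A minimal sketch of applying the same transform by hand follows; the KVCache constructor signature used here comes from llama_transformer.py, which is not part of this diff, so treat it as an assumption:

# Hedged sketch: apply the new source transform directly, mirroring what
# _get_source_transforms does when --quantize_kv_cache is set.
import torch.nn as nn

from executorch.examples.models.llama2.llama_transformer import KVCache
from executorch.examples.models.llama2.source_transformation.quantized_kv_cache import (
    replace_kv_cache_with_quantized_kv_cache,
)


class TinyAttention(nn.Module):
    def __init__(self):
        super().__init__()
        # Assumed KVCache signature: (max_batch_size, max_seq_length, n_heads,
        # head_dim, transpose_cache, enable_dynamic_shape); not shown in this diff.
        self.kv_cache = KVCache(1, 128, 8, 64, False, False)


model = nn.Sequential(TinyAttention(), TinyAttention())
model = replace_kv_cache_with_quantized_kv_cache(model)
print(model)  # each kv_cache child is now a QuantizedKVCache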
examples/models/llama2/source_transformation/TARGETS (new file)

Lines changed: 28 additions & 0 deletions

load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

oncall("executorch")

runtime.python_library(
    name = "quantized_kv_cache",
    srcs = [
        "quantized_kv_cache.py",
    ],
    _is_external_target = True,
    base_module = "executorch.examples.models.llama2.source_transformation",
    visibility = ["//executorch/..."],
    deps = [
        "//caffe2:torch",
    ],
)

runtime.python_test(
    name = "quantized_kv_cache_test",
    srcs = [
        "test_quantized_kv_cache.py",
    ],
    deps = [
        ":quantized_kv_cache",
        "//caffe2:torch",
        "//executorch/examples/models/llama2:llama_transformer",
    ],
)
examples/models/llama2/source_transformation/quantized_kv_cache.py (new file)

Lines changed: 197 additions & 0 deletions

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from enum import Enum

import torch
import torch.nn as nn
from executorch.examples.models.llama2.llama_transformer import KVCache
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401


"""
Heavily "inspired" by AO's implementation of the same in torchao/_models/llama/model.py
"""


# Doesn't have to abide by affine quantization laws.
# However, if we do implement quantized sdpa, then this might be handy.
class QuantizedCacheType(Enum):
    AffineSymmetric = 0
    AffineAsymmetric = 1
    AffineSymmetricGroupWise = 2
    AffineAsymmetricGroupWise = 3


class QuantizedKVCache(nn.Module):
    def __init__(
        self,
        max_batch_size,
        max_seq_length,
        n_heads,
        head_dim,
        cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
        tranposed=False,
        enable_dynamic_shape=False,
    ):
        super().__init__()
        if cache_type not in (
            QuantizedCacheType.AffineSymmetric,
            QuantizedCacheType.AffineAsymmetric,
        ):
            raise ValueError(
                f"Only affine symmetric and asymmetric cache types are supported: got {cache_type}"
            )
        # For now supporting int8 only
        self.quantized_cache_dtype = torch.int8
        self.cache_fp_type = torch.float32
        self.is_transposed = tranposed
        self.enable_dynamic_shape = enable_dynamic_shape
        if self.is_transposed:
            cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
            scale_shape = (max_batch_size, n_heads, max_seq_length, 1)
        else:
            cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
            scale_shape = (max_batch_size, max_seq_length, n_heads, 1)
        self.register_buffer(
            "k_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype)
        )
        self.register_buffer(
            "v_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype)
        )
        self.register_buffer(
            "k_cache_scales", torch.ones(scale_shape, dtype=torch.double)
        )
        self.register_buffer(
            "v_cache_scales", torch.ones(scale_shape, dtype=torch.double)
        )
        if cache_type == QuantizedCacheType.AffineAsymmetric:
            self.register_buffer(
                "k_cache_zero_points", torch.ones(scale_shape, dtype=torch.int64)
            )
            self.register_buffer(
                "v_cache_zero_points", torch.ones(scale_shape, dtype=torch.int64)
            )

    def _quantize(self, value):
        scales, zero_points = (
            torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(
                value, self.quantized_cache_dtype
            )
        )
        quantized_value = torch.ops.quantized_decomposed.quantize_per_token(
            value,
            scales,
            zero_points,
            torch.iinfo(self.quantized_cache_dtype).min,
            torch.iinfo(self.quantized_cache_dtype).max,
            self.quantized_cache_dtype,
        )
        return quantized_value, scales, zero_points

    def update(self, input_pos, k_val, v_val):
        # quantize current k_val and v_val and store them in the cache
        quantized_k_val, k_scales, k_zero_points = self._quantize(k_val)
        quantized_v_val, v_scales, v_zero_points = self._quantize(v_val)

        if self.enable_dynamic_shape:
            start_pos = input_pos[0].item()
            torch._check_is_size(start_pos)
            dim_to_slice = 2 if self.is_transposed else 1
            torch._check(start_pos < self.k_cache.size(dim_to_slice))
            seq_length = k_val.size(dim_to_slice)
            narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
            narrowed_k_scales = self.k_cache_scales.narrow(
                dim_to_slice, start_pos, seq_length
            )
            narrowed_k_zp = self.k_cache_zero_points.narrow(
                dim_to_slice, start_pos, seq_length
            )
            narrowed_k.copy_(quantized_k_val)
            narrowed_k_scales.copy_(k_scales)
            narrowed_k_zp.copy_(k_zero_points)
            # pyre-ignore: Incompatible parameter type [6]
            narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
            narrowed_v_scales = self.v_cache_scales.narrow(
                dim_to_slice, start_pos, seq_length
            )
            narrowed_v_zp = self.v_cache_zero_points.narrow(
                dim_to_slice, start_pos, seq_length
            )
            narrowed_v.copy_(quantized_v_val)
            narrowed_v_scales.copy_(v_scales)
            narrowed_v_zp.copy_(v_zero_points)
        else:
            if self.is_transposed:
                self.k_cache[:, :, input_pos] = quantized_k_val
                self.k_cache_scales[:, :, input_pos] = k_scales
                self.k_cache_zero_points[:, :, input_pos] = k_zero_points
                self.v_cache[:, :, input_pos] = quantized_v_val
                self.v_cache_scales[:, :, input_pos] = v_scales
                self.v_cache_zero_points[:, :, input_pos] = v_zero_points
            else:
                self.k_cache[:, input_pos] = quantized_k_val
                self.k_cache_scales[:, input_pos] = k_scales
                self.k_cache_zero_points[:, input_pos] = k_zero_points
                self.v_cache[:, input_pos] = quantized_v_val
                self.v_cache_scales[:, input_pos] = v_scales
                self.v_cache_zero_points[:, input_pos] = v_zero_points

        k_out = torch.ops.quantized_decomposed.dequantize_per_token(
            self.k_cache,
            self.k_cache_scales,
            self.k_cache_zero_points,
            torch.iinfo(self.quantized_cache_dtype).min,
            torch.iinfo(self.quantized_cache_dtype).max,
            self.quantized_cache_dtype,
            self.cache_fp_type,
        )
        v_out = torch.ops.quantized_decomposed.dequantize_per_token(
            self.v_cache,
            self.v_cache_scales,
            self.v_cache_zero_points,
            torch.iinfo(self.quantized_cache_dtype).min,
            torch.iinfo(self.quantized_cache_dtype).max,
            self.quantized_cache_dtype,
            self.cache_fp_type,
        )
        return k_out, v_out

    @classmethod
    def from_float(cls, kv_cache, cache_type: QuantizedCacheType):
        cache_shape = kv_cache.k_cache.shape
        if kv_cache.is_tranposed:
            max_batch_size, n_heads, max_seq_length, head_dim = cache_shape
        else:
            max_batch_size, max_seq_length, n_heads, head_dim = cache_shape
        return cls(
            max_batch_size,
            max_seq_length,
            n_heads,
            head_dim,
            cache_type,
            kv_cache.is_tranposed,
            kv_cache.enable_dynamic_shape,
        )


def replace_kv_cache_with_quantized_kv_cache(module):
    logging.warning(
        "Replacing KVCache with QuantizedKVCache. This modifies the model in place."
    )
    for name, child in module.named_children():
        if isinstance(child, KVCache):
            setattr(
                module,
                name,
                QuantizedKVCache.from_float(child, QuantizedCacheType.AffineAsymmetric),
            )
        else:
            replace_kv_cache_with_quantized_kv_cache(child)
    return module
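
To make the new cache's behavior concrete, here is a minimal sketch of driving QuantizedKVCache directly. It is illustrative only and is not the new test_quantized_kv_cache.py, which is not shown in this commit view; the shapes and position are arbitrary:

# Write one decode step of K/V into the quantized cache and read back the
# dequantized views.
import torch
from executorch.examples.models.llama2.source_transformation.quantized_kv_cache import (
    QuantizedCacheType,
    QuantizedKVCache,
)

batch, max_seq_len, n_heads, head_dim = 1, 16, 8, 64
cache = QuantizedKVCache(
    batch, max_seq_len, n_heads, head_dim, QuantizedCacheType.AffineAsymmetric
)

# One new token at position 3; layout is (batch, seq, heads, head_dim) since
# the cache was built with tranposed=False.
input_pos = torch.tensor([3], dtype=torch.int64)
k = torch.randn(batch, 1, n_heads, head_dim)
v = torch.randn(batch, 1, n_heads, head_dim)

k_out, v_out = cache.update(input_pos, k, v)  # dequantized full-length caches
print(k_out.shape)                                   # (1, 16, 8, 64)
print(torch.max(torch.abs(k_out[:, 3] - k[:, 0])))   # small int8 quantization error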
