
Commit eb360ee

Add page attention manager and kvcache manager (#167)

* Add page attention manager and kvcache manager
* adapt prefill update layer new api
* use tensor indices
* lint format

1 parent ee040a4 commit eb360ee

4 files changed: +517 -0 lines changed

jetstream_pt/cache_manager.py

Lines changed: 58 additions & 0 deletions
@@ -654,3 +654,61 @@ def finalize(self):
        self.new_v_scaler,
        self.input_pos,
    )


class PageKVCacheGenerate:
  """Page attention kv cache generator without quantization"""

  def __init__(
      self,
      cache_k: torch.Tensor,  # previous cache
      cache_v: torch.Tensor,  # previous cache
      page_token_indices: torch.Tensor,  # page and token indices for the cache
      sharding,
      env=None,
  ):
    super().__init__()
    self.cache_k = cache_k
    self.cache_v = cache_v
    self.page_token_indices = page_token_indices
    self.sharding = sharding
    self.env = env

  def update(self, key, value):
    """Update kv cache"""
    keyj, valuej, page_token_indicesj = torchjax.from_torch(
        (key, value, self.page_token_indices)
    )

    def _update(cache, x):
      x = x.squeeze(2).transpose((1, 0, 2))
      x = x[:, page_token_indicesj[2], :]
      head, _, page_size, dim = cache.shape
      selected_cache = cache[:, page_token_indicesj[0], :, :]
      selected_cache = selected_cache.reshape((head, -1, dim))

      selected_cache = selected_cache.at[:, page_token_indicesj[1], :].set(x)
      selected_cache = selected_cache.reshape((head, -1, page_size, dim))

      cache = cache.at[:, page_token_indicesj[0], :, :].set(selected_cache)
      return cache

    # pylint: disable-next=all
    self.cache_k._elem = _update(self.cache_k._elem, keyj)
    # pylint: disable-next=all
    self.cache_v._elem = _update(self.cache_v._elem, valuej)
    return self.cache_k, self.cache_v

  def state(self):
    """Get kv cache state"""
    # pylint: disable-next=all
    return self.cache_k.jax(), self.cache_v.jax()

  @classmethod
  def empty(cls, shape, device, bf16_enable, env):
    """Create empty kv caches"""
    default_dtype = jnp.bfloat16 if bf16_enable else jnp.float32
    k = jnp.zeros(shape, device=device, dtype=default_dtype)
    v = jnp.zeros(shape, device=device, dtype=default_dtype)
    k, v = torchjax.to_torch((k, v))
    return cls(k, v, None, device, env=env)
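
For orientation, here is a minimal, hypothetical sketch (not part of this commit) of driving PageKVCacheGenerate for a single decode step with hand-built page/token indices; in real use these indices come from PageAttentionManager.get_page_token_indices, defined in the next file.

# Hypothetical usage sketch; shapes follow the unit test below, index values are illustrative.
import jax.numpy as jnp

from jetstream_pt import torchjax
from jetstream_pt.cache_manager import PageKVCacheGenerate

# Decode cache for one layer: [num_kv_heads=1, total_num_pages=20, page_size=4, head_dim=2].
cache = PageKVCacheGenerate.empty(
    shape=(1, 20, 4, 2), device=None, bf16_enable=True, env=None
)

# Rows: page index to write, offset into the gathered pages, batch slot, new length.
# Here slot 0 writes row 2 of page 3, and slot 2 writes row 0 of page 7.
cache.page_token_indices = torchjax.to_torch(
    jnp.asarray([[3, 7], [2, 4], [0, 2], [7, 9]])
)

# New key/value for one decode step: [batch_size=3, num_kv_heads=1, seq=1, head_dim=2].
xk = torchjax.to_torch(jnp.ones((3, 1, 1, 2), dtype=jnp.bfloat16))
xv = torchjax.to_torch(jnp.ones((3, 1, 1, 2), dtype=jnp.bfloat16))
cache_k, cache_v = cache.update(xk, xv)
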
jetstream_pt/page_attention_manager.py

Lines changed: 192 additions & 0 deletions
@@ -0,0 +1,192 @@
import queue
import functools
from typing import List, Tuple

import jax
import jax.numpy as jnp
import jax.sharding as jsharding


class PageAttentionManager:
  """Manages page blocks.

  This manager maintains a main list of free page blocks and supports:
    1. Reserving pages for prefill insert and decode.
    2. Freeing page resources for slots after decode; freed page indices
       return to the free list.
    3. Getting page-index metadata for all slots.
    4. Transforming and inserting prefill caches into decode caches.
  """

  def __init__(
      self,
      batch_size: int,
      total_num_pages: int,
      page_size: int,
      max_pages_per_sequence: int,
  ):
    self.unused_pages = queue.Queue()
    self.batch_size = batch_size
    self.page_indices = jnp.full(
        (batch_size, max_pages_per_sequence), -1, dtype=jnp.int32
    )
    self.lengths = jnp.zeros(batch_size, dtype=jnp.int32)
    self.page_size = page_size
    self.max_pages_per_sequence = max_pages_per_sequence
    for i in range(total_num_pages):
      self.unused_pages.put(i, block=False)

  # pylint: disable-next=all
  def reserve_pages_insert(self, slot: int, seq_len: int) -> int:
    self.lengths = self.lengths.at[slot].set(seq_len)
    num_pages = seq_len // self.page_size
    if seq_len % self.page_size != 0:
      num_pages = num_pages + 1

    indices = [self.unused_pages.get(block=False) for _ in range(num_pages)]
    self.page_indices = self.page_indices.at[slot, :num_pages].set(indices)
    return num_pages

  # pylint: disable-next=all
  def reserve_pages_decode(self, slot: int, seq_len: int):
    if seq_len > 0 and seq_len % self.page_size == 0:
      index = self.unused_pages.get(block=False)
      num_pages = seq_len // self.page_size
      self.page_indices = self.page_indices.at[slot, num_pages].set(index)

  # pylint: disable-next=all
  def prefill_cache_padding(
      self,
      caches: List[Tuple[jax.Array, jax.Array]],
      seq_len: int,
      num_pages: int,
  ) -> List[Tuple[jax.Array, jax.Array]]:

    pad_width = num_pages * self.page_size - seq_len
    if pad_width == 0:
      return caches

    caches = [
        (self.pad_sequences(k, pad_width), self.pad_sequences(v, pad_width))
        for k, v in caches
    ]
    return caches

  def insert_prefill_cache(
      self,
      prefill_caches: List[Tuple[jax.Array, jax.Array]],
      decode_caches: List[Tuple[jax.Array, jax.Array]],
      slot: int,
      seq_len: int,
      sharding: jsharding.Sharding,
  ) -> List[Tuple[jax.Array, jax.Array]]:
    """Insert prefill caches into the decode caches at the given slot.

    Args:
      prefill_caches: List of (K, V) tuples. Each K, V:
        [batch_size, num_heads, seq_len, head_dim] jax.Array.
      decode_caches: List of (K, V) tuples. Each K, V:
        [num_heads, total_num_pages, page_size, head_dim] jax.Array.
      slot: Batch slot of the decode cache to fill.
      seq_len: Prefill token sequence length.
      sharding: Decode cache sharding.


    Returns:
      Decode caches. List of (K, V) tuples. Each K, V:
        [num_heads, total_num_pages, page_size, head_dim] jax.Array.
    """

    num_pages = self.reserve_pages_insert(slot, seq_len)
    padded_caches = self.prefill_cache_padding(
        prefill_caches, seq_len, num_pages
    )
    # Drop the cache batch dimension
    # [kv_heads, seq_len, dim]
    squeezed_caches = [
        (jnp.squeeze(k, axis=0), jnp.squeeze(v, axis=0))
        for k, v in padded_caches
    ]
    kv_heads, _, dim = squeezed_caches[0][0].shape
    # [kv_heads, num_pages, page_size, dim]
    paged_caches = [
        (
            jnp.reshape(k, (kv_heads, -1, self.page_size, dim)),
            jnp.reshape(v, (kv_heads, -1, self.page_size, dim)),
        )
        for k, v in squeezed_caches
    ]
    update_indexes = self.page_indices[slot, :num_pages]

    @functools.partial(jax.jit, donate_argnums=(0, 1), inline=True)
    def insert(cache, new_entry):
      new_entry = new_entry.squeeze(0)
      res = cache.at[:, update_indexes, :, :].set(new_entry)
      res = jax.lax.with_sharding_constraint(res, sharding)
      return res

    caches = [
        (insert(k, newk), insert(v, newv))
        for (k, v), (newk, newv) in zip(decode_caches, paged_caches)
    ]

    return caches

  # pylint: disable-next=all
  def get_page_token_indices(self, lens: jax.Array) -> jax.Array:

    assert lens.shape == (
        self.batch_size,
        1,
    ), f"lens shape: {lens.shape} does not match {self.batch_size, 1}"
    update_page_indices = []
    token_scale_indices = []
    batch_slots = []
    new_lens = []
    offset = 0
    for slot in range(self.batch_size):
      seq_len = lens[slot][0]
      num_pages = seq_len // self.page_size + 1
      token_pos = seq_len % self.page_size
      page_index = self.page_indices[slot, num_pages - 1]
      if page_index < 0:
        continue
      update_page_indices.append(page_index)
      token_scale_indices.append(offset + token_pos)
      batch_slots.append(slot)
      new_lens.append(seq_len + 1)
      offset += self.page_size
    return jnp.stack(
        (
            jnp.asarray(update_page_indices),
            jnp.asarray(token_scale_indices),
            jnp.asarray(batch_slots),
            jnp.asarray(new_lens),
        )
    )

  # pylint: disable-next=all
  def fill_new_pages(self, lens: jax.Array):
    for slot in range(self.batch_size):
      self.reserve_pages_decode(slot, lens[slot])

  # pylint: disable-next=all
  def pad_sequences(self, array, pad_width=10):
    padding_config = [
        (0, 0),
        (0, 0),
        (0, pad_width),
        (0, 0),
    ]  # Pad only the seq_len axis
    padded_array = jnp.pad(array, padding_config, mode="constant")
    return padded_array

  # pylint: disable-next=all
  def free_pages_resource(self, slot):
    for i in range(self.max_pages_per_sequence):
      index = self.page_indices[slot, i]
      if index < 0:
        break
      self.unused_pages.put(index, block=False)

    self.page_indices = self.page_indices.at[slot, :].set(jnp.asarray([-1]))
    return None
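
The bookkeeping above reduces to simple arithmetic: a prefill of seq_len tokens reserves ceil(seq_len / page_size) pages and zero-pads the last page, and a decode step takes a new page off the free list only when the current length lands exactly on a page boundary. The lifecycle sketch below (hypothetical usage, not part of this commit) walks through those rules with the same constructor arguments as the unit test that follows.

# Hypothetical lifecycle sketch for PageAttentionManager (page_size=4).
import jax.numpy as jnp

from jetstream_pt.page_attention_manager import PageAttentionManager

pam = PageAttentionManager(
    batch_size=5, total_num_pages=20, page_size=4, max_pages_per_sequence=4
)

# A 13-token prefill in slot 3 reserves ceil(13 / 4) = 4 pages; the last page
# holds one real token and three zero-padded positions.
num_pages = pam.reserve_pages_insert(slot=3, seq_len=13)
assert num_pages == 4

# Before each decode step: grow pages where needed, then fetch per-slot write targets.
lens = jnp.asarray([0, 0, 0, 13, 0]).reshape(5, 1)
pam.fill_new_pages(lens)  # 13 % 4 != 0, so slot 3 needs no new page yet
indices = pam.get_page_token_indices(lens)
# indices rows: page index, offset into the gathered pages, slot, new length.

# When slot 3's request finishes, its pages go back to the free list.
pam.free_pages_resource(slot=3)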

tests/test_kv_cache_manager.py

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
import unittest

import jax
import jax.numpy as jnp
import torch

from jetstream_pt.third_party.llama import model_args
from jetstream_pt import environment
from jetstream_pt.page_attention_manager import PageAttentionManager
from jetstream_pt.cache_manager import PageKVCacheGenerate, KVCachePrefill
from jetstream_pt import torchjax
from absl.testing import parameterized


class PageAttentionTest(parameterized.TestCase):

  def _make_env(self, bf16_enable=True):
    torch_dtype = torch.bfloat16 if bf16_enable else torch.float32
    torch.set_default_dtype(torch_dtype)
    jax.config.update("jax_dynamic_shapes", False)
    jax.config.update("jax_traceback_filtering", "off")
    jax.config.update("jax_platform_name", "cpu")
    config = model_args.get_model_args("tiny", 128, 1, True)
    environment_data = environment.JetEngineEnvironmentData()
    environment_data.max_input_sequence_length = 128
    environment_data.cache_sequence_length = 128
    environment_data.bf16_enable = bf16_enable
    environment_data.model_type = "llama-2-tiny"
    environment_data.batch_size = 3
    environment_data.num_layers = config.n_layers
    environment_data.cache_shape = (
        1,
        config.n_kv_heads,
        environment_data.cache_sequence_length,
        config.dim // config.n_heads,
    )
    env = environment.JetEngineEnvironment(environment_data)
    env.apply_sharding = lambda *args, **kwargs: None  # don't shard on cpu
    return env, config

  def test_page_attention_update(self):
    jax.config.update("jax_platform_name", "cpu")
    print(f"---------> {jax.devices()}")

    env, _ = self._make_env()

    pam = PageAttentionManager(
        batch_size=5, total_num_pages=20, page_size=4, max_pages_per_sequence=4
    )
    shape = (1, 20, 4, 2)
    decode_caches = []
    decode_caches.append(
        PageKVCacheGenerate.empty(
            shape=shape, device=None, bf16_enable=True, env=env
        )
    )
    decode_caches = [c.state() for c in decode_caches]

    self.cache_sharding = env.cache_sharding

    def _insert_prefill(seq_len, dim, slot):
      prefill_cache = KVCachePrefill()
      k, v = jnp.arange(seq_len * dim), jnp.arange(seq_len * dim)
      k, v = jnp.reshape(k, (1, 1, seq_len, dim)), jnp.reshape(
          v, (1, 1, seq_len, dim)
      )
      prefill_cache.update(k, v, 0)
      prefill_caches = [prefill_cache]
      prefill_caches = [c.state() for c in prefill_caches]

      return pam.insert_prefill_cache(
          prefill_caches, decode_caches, slot, seq_len, env.cache_sharding
      )

    decode_caches = _insert_prefill(3, 2, 0)
    decode_caches = _insert_prefill(8, 2, 1)
    decode_caches = _insert_prefill(13, 2, 3)

    lens = jnp.asarray([3, 8, 0, 13, 0]).reshape(5, 1)
    pam.fill_new_pages(lens)
    page_token_indices = pam.get_page_token_indices(lens)
    page_token_indices = torchjax.to_torch(page_token_indices)

    caches_obj = [
        PageKVCacheGenerate(
            k, v, page_token_indices, self.cache_sharding, env=env
        )
        for k, v in torchjax.to_torch(decode_caches)
    ]
    xk, xv = jnp.arange(-1, -11, -1).reshape(5, 1, 1, 2), jnp.arange(
        -1, -11, -1
    ).reshape(5, 1, 1, 2)
    xk = torchjax.to_torch(xk)
    xv = torchjax.to_torch(xv)
    decode_caches = caches_obj[0].update(xk, xv)
    expected = jnp.asarray([[0, 1], [2, 3], [4, 5], [-1, -2]])
    self.assertTrue(jnp.array_equal(decode_caches[0][0][0], expected))
    expected = jnp.asarray([[-3, -4], [0, 0], [0, 0], [0, 0]])
    self.assertTrue(jnp.array_equal(decode_caches[0][0][7], expected))
    expected = jnp.asarray([[24, 25], [-7, -8], [0, 0], [0, 0]])
    self.assertTrue(jnp.array_equal(decode_caches[0][0][6], expected))


if __name__ == "__main__":
  unittest.main()
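
To trace the assertions above: with page_size=4 the three prefills occupy page 0 (slot 0), pages 1-2 (slot 1), and pages 3-6 (slot 3), and fill_new_pages hands slot 1 a fresh page 7 because its length of 8 sits exactly on a page boundary. The decode step then writes slot 0's key [-1, -2] into row 3 of page 0, slot 1's [-3, -4] into row 0 of page 7, and slot 3's [-7, -8] into row 1 of page 6, which is what the three expected arrays assert. The index bookkeeping alone can be sanity-checked without the cache update (a sketch under the same parameters, not part of the commit):

import jax.numpy as jnp

from jetstream_pt.page_attention_manager import PageAttentionManager

pam = PageAttentionManager(
    batch_size=5, total_num_pages=20, page_size=4, max_pages_per_sequence=4
)
for slot, seq_len in ((0, 3), (1, 8), (3, 13)):
  pam.reserve_pages_insert(slot, seq_len)

lens = jnp.asarray([3, 8, 0, 13, 0]).reshape(5, 1)
pam.fill_new_pages(lens)
print(pam.get_page_token_indices(lens))
# Rows are (page, offset, slot, new length); the expected values are
# [[0, 7, 6], [3, 4, 9], [0, 1, 3], [4, 9, 14]].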
