Add page attention manager and kvcache manager #167

Merged
merged 4 commits on Aug 6, 2024
58 changes: 58 additions & 0 deletions jetstream_pt/cache_manager.py
@@ -654,3 +654,61 @@ def finalize(self):
self.new_v_scaler,
self.input_pos,
)


class PageKVCacheGenerate:
"""Page attention kvache generator without quantization"""

def __init__(
self,
cache_k: torch.Tensor, # previous cache
cache_v: torch.Tensor, # previous cache
page_token_indices: torch.Tensor, # page and token indices for the cache
sharding,
env=None,
):
super().__init__()
self.cache_k = cache_k
self.cache_v = cache_v
self.page_token_indices = page_token_indices
self.sharding = sharding
self.env = env

def update(self, key, value):
"""Update kv cache"""
keyj, valuej, page_token_indicesj = torchjax.from_torch(
(key, value, self.page_token_indices)
)

def _update(cache, x):
x = x.squeeze(2).transpose((1, 0, 2))
x = x[:, page_token_indicesj[2], :]
head, _, page_size, dim = cache.shape
selected_cache = cache[:, page_token_indicesj[0], :, :]
selected_cache = selected_cache.reshape((head, -1, dim))

selected_cache = selected_cache.at[:, page_token_indicesj[1], :].set(x)
selected_cache = selected_cache.reshape((head, -1, page_size, dim))

cache = cache.at[:, page_token_indicesj[0], :, :].set(selected_cache)
return cache

# pylint: disable-next=all
self.cache_k._elem = _update(self.cache_k._elem, keyj)
# pylint: disable-next=all
self.cache_v._elem = _update(self.cache_v._elem, valuej)
return self.cache_k, self.cache_v

def state(self):
"""Get kv cache state"""
# pylint: disable-next=all
return self.cache_k.jax(), self.cache_v.jax()

@classmethod
def empty(cls, shape, device, bf16_enable, env):
"""Create empty kv caches"""
default_dtype = jnp.bfloat16 if bf16_enable else jnp.float32
k = jnp.zeros(shape, device=device, dtype=default_dtype)
v = jnp.zeros(shape, device=device, dtype=default_dtype)
k, v = torchjax.to_torch((k, v))
return cls(k, v, None, device, env=env)
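
The update path above depends on the row layout of page_token_indices produced by PageAttentionManager.get_page_token_indices below: row 0 holds the page index to update for each active slot, row 1 the token position inside the gathered pages after they are flattened, and row 2 the active batch slots used to pick the new key/value rows. As a minimal, self-contained sketch of that gather/set/scatter step (the shapes and index values here are illustrative only, not taken from the PR), the same update can be written in plain JAX:

import jax.numpy as jnp

heads, total_pages, page_size, dim = 1, 6, 4, 2
cache = jnp.zeros((heads, total_pages, page_size, dim))

# Illustrative indices for two active slots:
# slot A writes into page 2 at token position 3, slot B into page 5 at position 1.
page_indices = jnp.asarray([2, 5])                         # like page_token_indices[0]
token_indices = jnp.asarray([0 * page_size + 3,            # like page_token_indices[1]
                             1 * page_size + 1])
new_kv = jnp.ones((heads, 2, dim))                         # one new token per active slot

# Gather the touched pages, flatten them to (heads, pages * page_size, dim),
# write the new tokens at their flattened positions, then scatter the pages back.
selected = cache[:, page_indices, :, :].reshape((heads, -1, dim))
selected = selected.at[:, token_indices, :].set(new_kv)
cache = cache.at[:, page_indices, :, :].set(
    selected.reshape((heads, -1, page_size, dim))
)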
192 changes: 192 additions & 0 deletions jetstream_pt/page_attention_manager.py
@@ -0,0 +1,192 @@
import queue
import functools
from typing import List, Tuple

import jax
import jax.numpy as jnp
import jax.sharding as jsharding


class PageAttentionManager:
"""Manages page blocks.

  This manager maintains a pool of free page blocks and supports the following operations:
  1. Reserve pages for prefill insert and for decode.
  2. Free a slot's pages after decode; freed page indices return to the free pool.
  3. Get page-index metadata for all slots.
  4. Transform prefill caches and insert them into decode caches.
"""

def __init__(
self,
batch_size: int,
total_num_pages: int,
page_size: int,
max_pages_per_sequence: int,
):
self.unused_pages = queue.Queue()
Collaborator: use deque

Collaborator (author, @FanhaiLu1, Aug 6, 2024): In the current JetStream implementation, detokenize_threads and _generate_threads are separate threads, and both need to access this queue. So the queue must be thread-safe, and deque is not thread-safe.
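
As a minimal illustration of that concern (the worker functions, page numbers, and pool size below are hypothetical, not JetStream code), queue.Queue lets two threads share the free-page pool without any extra locking:

import queue
import threading

free_pages = queue.Queue()
for i in range(8):
    free_pages.put(i, block=False)

def generate_worker():
    # The generate thread reserves a page for the next decode step.
    page = free_pages.get()
    print(f"generate reserved page {page}")

def detokenize_worker():
    # The detokenize thread returns a finished slot's page to the pool.
    free_pages.put(42, block=False)
    print("detokenize freed page 42")

threads = [
    threading.Thread(target=generate_worker),
    threading.Thread(target=detokenize_worker),
]
for t in threads:
    t.start()
for t in threads:
    t.join()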

self.batch_size = batch_size
self.page_indices = jnp.full(
(batch_size, max_pages_per_sequence), -1, dtype=jnp.int32
)
self.lengths = jnp.zeros(batch_size, dtype=jnp.int32)
self.page_size = page_size
self.max_pages_per_sequence = max_pages_per_sequence
for i in range(total_num_pages):
self.unused_pages.put(i, block=False)

# pylint: disable-next=all
def reserve_pages_insert(self, slot: int, seq_len: int) -> int:
self.lengths = self.lengths.at[slot].set(seq_len)
num_pages = seq_len // self.page_size
if seq_len % self.page_size != 0:
num_pages = num_pages + 1

indices = [self.unused_pages.get(block=False) for _ in range(num_pages)]
self.page_indices = self.page_indices.at[slot, :num_pages].set(indices)
return num_pages

# pylint: disable-next=all
def reserve_pages_decode(self, slot: int, seq_len: int):
if seq_len > 0 and seq_len % self.page_size == 0:
index = self.unused_pages.get(block=False)
num_pages = seq_len // self.page_size
self.page_indices = self.page_indices.at[slot, num_pages].set(index)

# pylint: disable-next=all
def prefill_cache_padding(
self,
caches: List[Tuple[jax.Array, jax.Array]],
seq_len: int,
num_pages: int,
) -> List[Tuple[jax.Array, jax.Array]]:

pad_width = num_pages * self.page_size - seq_len
if pad_width == 0:
return caches

caches = [
(self.pad_sequences(k, pad_width), self.pad_sequences(v, pad_width))
for k, v in caches
]
return caches

def insert_prefill_cache(
self,
prefill_caches: List[Tuple[jax.Array, jax.Array]],
decode_caches: List[Tuple[jax.Array, jax.Array]],
slot: int,
seq_len: int,
sharding: jsharding.Sharding,
) -> List[Tuple[jax.Array, jax.Array]]:
"""Insert prefill caches to decode caches slot.

Args:
prefill_caches: List of Tuple K, V. For each K, V:
[batch_size, num_heads, seq_len, head_dim] jax.Array.
decode_caches: List of Tuple K, V. For each K, V:
[num_heads, total_num_pages, page_size, head_dim] jax.Array.
slot: Batch slot index in decode.
seq_len: Prefill token sequence length.
sharding: Decode cache sharding.


Returns:
Decode cache. List of Tuple K, V. For each K, V:
[num_heads, total_num_pages, page_size, head_dim] jax.Array.
"""

num_pages = self.reserve_pages_insert(slot, seq_len)
padded_caches = self.prefill_cache_padding(
prefill_caches, seq_len, num_pages
)
# Reduce cache batch dimension
# [kv_heads, seq_len, dim]
squeezed_caches = [
(jnp.squeeze(k, axis=0), jnp.squeeze(v, axis=0))
for k, v in padded_caches
]
kv_heads, _, dim = squeezed_caches[0][0].shape
# [kv_heads, num_pages, page_size, dim]
paged_caches = [
(
jnp.reshape(k, (kv_heads, -1, self.page_size, dim)),
jnp.reshape(v, (kv_heads, -1, self.page_size, dim)),
)
for k, v in squeezed_caches
]
update_indexes = self.page_indices[slot, :num_pages]

@functools.partial(jax.jit, donate_argnums=(0, 1), inline=True)
def insert(cache, new_entry):
new_entry = new_entry.squeeze(0)
res = cache.at[:, update_indexes, :, :].set(new_entry)
res = jax.lax.with_sharding_constraint(res, sharding)
return res

caches = [
(insert(k, newk), insert(v, newv))
for (k, v), (newk, newv) in zip(decode_caches, paged_caches)
]

return caches

# pylint: disable-next=all
def get_page_token_indices(self, lens: jax.Array) -> jax.Array:

assert lens.shape == (
self.batch_size,
1,
), f"len shape: {lens.shape} not equals batch size: {self.batch_size, 1}"
update_page_indices = []
token_scale_indices = []
batch_slots = []
new_lens = []
offset = 0
for slot in range(self.batch_size):
seq_len = lens[slot][0]
num_pages = seq_len // self.page_size + 1
token_pos = seq_len % self.page_size
page_index = self.page_indices[slot, num_pages - 1]
if page_index < 0:
continue
update_page_indices.append(page_index)
token_scale_indices.append(offset + token_pos)
batch_slots.append(slot)
new_lens.append(seq_len + 1)
offset += self.page_size
return jnp.stack(
(
jnp.asarray(update_page_indices),
jnp.asarray(token_scale_indices),
jnp.asarray(batch_slots),
jnp.asarray(new_lens),
)
)

# pylint: disable-next=all
def fill_new_pages(self, lens: jax.Array):
for slot in range(self.batch_size):
self.reserve_pages_decode(slot, lens[slot])

# pylint: disable-next=all
def pad_sequences(self, array, pad_width=10):
padding_config = [
(0, 0),
(0, 0),
(0, pad_width),
(0, 0),
] # Pad only seq_len and dim
padded_array = jnp.pad(array, padding_config, mode="constant")
return padded_array

# pylint: disable-next=all
def free_pages_resource(self, slot):
for i in range(self.max_pages_per_sequence):
index = self.page_indices[slot, i]
if index < 0:
break
self.unused_pages.put(index, block=False)

self.page_indices = self.page_indices.at[slot, :].set(jnp.asarray([-1]))
return None
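
Taken together, the intended lifecycle of the manager is: reserve pages when a prefill lands in a slot, reserve one more page for a slot whenever decode crosses a page boundary, and return a slot's pages to the free pool when its request finishes. A minimal usage sketch of that flow (the method names and constructor arguments come from this file, but the sizes are illustrative):

from jetstream_pt.page_attention_manager import PageAttentionManager
import jax.numpy as jnp

pam = PageAttentionManager(
    batch_size=2, total_num_pages=16, page_size=4, max_pages_per_sequence=8
)

# A 6-token prefill lands in slot 0: ceil(6 / 4) = 2 pages are reserved.
num_pages = pam.reserve_pages_insert(slot=0, seq_len=6)

# During decode, lens holds the current length per slot; a slot whose length
# is a multiple of page_size gets one additional page reserved.
lens = jnp.asarray([8, 0]).reshape(2, 1)
pam.fill_new_pages(lens)
page_token_indices = pam.get_page_token_indices(lens)

# Once slot 0 finishes, its pages go back to the free pool.
pam.free_pages_resource(slot=0)

In the full flow, insert_prefill_cache calls reserve_pages_insert internally and also copies the prefill KV into the paged decode cache, which is what the test below exercises.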
106 changes: 106 additions & 0 deletions tests/test_kv_cache_manager.py
@@ -0,0 +1,106 @@
import unittest

import jax
import jax.numpy as jnp
import torch

from jetstream_pt.third_party.llama import model_args
from jetstream_pt import environment
from jetstream_pt.page_attention_manager import PageAttentionManager
from jetstream_pt.cache_manager import PageKVCacheGenerate, KVCachePrefill
from jetstream_pt import torchjax
from absl.testing import parameterized


class PageAttentionTest(parameterized.TestCase):

def _make_env(self, bf16_enable=True):
torch_dtype = torch.bfloat16 if bf16_enable else torch.float32
torch.set_default_dtype(torch_dtype)
jax.config.update("jax_dynamic_shapes", False)
jax.config.update("jax_traceback_filtering", "off")
jax.config.update("jax_platform_name", "cpu")
config = model_args.get_model_args("tiny", 128, 1, True)
environment_data = environment.JetEngineEnvironmentData()
environment_data.max_input_sequence_length = 128
environment_data.cache_sequence_length = 128
environment_data.bf16_enable = bf16_enable
environment_data.model_type = "llama-2-tiny"
environment_data.batch_size = 3
environment_data.num_layers = config.n_layers
environment_data.cache_shape = (
1,
config.n_kv_heads,
environment_data.cache_sequence_length,
config.dim // config.n_heads,
)
env = environment.JetEngineEnvironment(environment_data)
env.apply_sharding = lambda *args, **kwargs: None # don't shard on cpu
return env, config

def test_page_attention_update(self):
jax.config.update("jax_platform_name", "cpu")
print(f"---------> {jax.devices()}")

env, _ = self._make_env()

pam = PageAttentionManager(
batch_size=5, total_num_pages=20, page_size=4, max_pages_per_sequence=4
)
shape = (1, 20, 4, 2)
decode_caches = []
decode_caches.append(
PageKVCacheGenerate.empty(
shape=shape, device=None, bf16_enable=True, env=env
)
)
decode_caches = [c.state() for c in decode_caches]

self.cache_sharding = env.cache_sharding

def _insert_prefill(seq_len, dim, slot):
prefill_cache = KVCachePrefill()
k, v = jnp.arange(seq_len * dim), jnp.arange(seq_len * dim)
k, v = jnp.reshape(k, (1, 1, seq_len, dim)), jnp.reshape(
v, (1, 1, seq_len, dim)
)
prefill_cache.update(k, v, 0)
prefill_caches = [prefill_cache]
prefill_caches = [c.state() for c in prefill_caches]

return pam.insert_prefill_cache(
prefill_caches, decode_caches, slot, seq_len, env.cache_sharding
)

decode_caches = _insert_prefill(3, 2, 0)
decode_caches = _insert_prefill(8, 2, 1)
decode_caches = _insert_prefill(13, 2, 3)

lens = jnp.asarray([3, 8, 0, 13, 0]).reshape(5, 1)
pam.fill_new_pages(lens)
page_token_indices = pam.get_page_token_indices(lens)
page_token_indices = torchjax.to_torch(page_token_indices)

caches_obj = [
PageKVCacheGenerate(
k, v, page_token_indices, self.cache_sharding, env=env
)
for k, v in torchjax.to_torch(decode_caches)
]
xk, xv = jnp.arange(-1, -11, -1).reshape(5, 1, 1, 2), jnp.arange(
-1, -11, -1
).reshape(5, 1, 1, 2)
xk = torchjax.to_torch(xk)
xv = torchjax.to_torch(xv)
decode_caches = caches_obj[0].update(xk, xv)
expected = jnp.asarray([[0, 1], [2, 3], [4, 5], [-1, -2]])
self.assertTrue(jnp.array_equal(decode_caches[0][0][0], expected))
expected = jnp.asarray([[-3, -4], [0, 0], [0, 0], [0, 0]])
self.assertTrue(jnp.array_equal(decode_caches[0][0][7], expected))
expected = jnp.asarray([[24, 25], [-7, -8], [0, 0], [0, 0]])
self.assertTrue(jnp.array_equal(decode_caches[0][0][6], expected))


if __name__ == "__main__":
unittest.main()