-
Notifications
You must be signed in to change notification settings - Fork 17
Add page attention manager and kvcache manager #167
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
FanhaiLu1
merged 4 commits into
AI-Hypercomputer:main
from
FanhaiLu1:pa_decode_checkin_2
Aug 6, 2024
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
import queue | ||
import functools | ||
from typing import List, Tuple | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
import jax.sharding as jsharding | ||
|
||
|
||
class PageAttentionManager:
  """Manages page blocks for paged-attention KV caches.

  This manager maintains a main list of free page blocks and supports:
    1. Reserving pages for prefill insert and decode.
    2. Freeing a slot's pages after decode; freed page indices go back
       to the free list.
    3. Computing page-index metadata for all slots.
    4. Transforming and inserting prefill caches into decode caches.
  """

  def __init__(
      self,
      batch_size: int,
      total_num_pages: int,
      page_size: int,
      max_pages_per_sequence: int,
  ):
    # queue.Queue (not a deque) because the detokenize and generate
    # threads both touch the free list, and Queue is thread-safe.
    self.unused_pages = queue.Queue()
    self.batch_size = batch_size
    # page_indices[slot, i] is the i-th page assigned to `slot`;
    # -1 means unassigned.
    self.page_indices = jnp.full(
        (batch_size, max_pages_per_sequence), -1, dtype=jnp.int32
    )
    # Current sequence length (in tokens) of each slot.
    self.lengths = jnp.zeros(batch_size, dtype=jnp.int32)
    self.page_size = page_size
    self.max_pages_per_sequence = max_pages_per_sequence
    for i in range(total_num_pages):
      self.unused_pages.put(i, block=False)

  # pylint: disable-next=all
  def reserve_pages_insert(self, slot: int, seq_len: int) -> int:
    """Reserves enough pages in `slot` to hold `seq_len` tokens.

    Args:
      slot: Decode batch slot to reserve pages for.
      seq_len: Prefill sequence length in tokens.

    Returns:
      The number of pages reserved.

    Raises:
      queue.Empty: If the free list holds fewer pages than required.
    """
    self.lengths = self.lengths.at[slot].set(seq_len)
    # Ceiling division: a partially filled last page still needs a full page.
    num_pages = (seq_len + self.page_size - 1) // self.page_size

    indices = [self.unused_pages.get(block=False) for _ in range(num_pages)]
    self.page_indices = self.page_indices.at[slot, :num_pages].set(indices)
    return num_pages

  # pylint: disable-next=all
  def reserve_pages_decode(self, slot: int, seq_len: int):
    """Reserves one extra page for `slot` when its last page is exactly full.

    A new page is needed only when `seq_len` is a positive multiple of the
    page size, since the next decoded token would not fit otherwise.

    Raises:
      queue.Empty: If the free list is empty.
    """
    if seq_len > 0 and seq_len % self.page_size == 0:
      index = self.unused_pages.get(block=False)
      num_pages = seq_len // self.page_size
      self.page_indices = self.page_indices.at[slot, num_pages].set(index)

  # pylint: disable-next=all
  def prefill_cache_padding(
      self,
      caches: List[Tuple[jax.Array, jax.Array]],
      seq_len: int,
      num_pages: int,
  ) -> List[Tuple[jax.Array, jax.Array]]:
    """Zero-pads prefill K/V caches up to a whole number of pages."""
    pad_width = num_pages * self.page_size - seq_len
    if pad_width == 0:
      return caches

    return [
        (self.pad_sequences(k, pad_width), self.pad_sequences(v, pad_width))
        for k, v in caches
    ]

  def insert_prefill_cache(
      self,
      prefill_caches: List[Tuple[jax.Array, jax.Array]],
      decode_caches: List[Tuple[jax.Array, jax.Array]],
      slot: int,
      seq_len: int,
      sharding: jsharding.Sharding,
  ) -> List[Tuple[jax.Array, jax.Array]]:
    """Inserts prefill caches into the decode caches at `slot`.

    Args:
      prefill_caches: List of Tuple K, V. For each K, V:
        [batch_size, num_heads, seq_len, head_dim] jax.Array.
      decode_caches: List of Tuple K, V. For each K, V:
        [num_heads, total_num_pages, page_size, head_dim] jax.Array.
      slot: Slot of batch size in decode.
      seq_len: Prefill tokens sequence length.
      sharding: Decode cache sharding.

    Returns:
      Decode cache. List of Tuple K, V. For each K, V:
      [num_heads, total_num_pages, page_size, head_dim] jax.Array.
    """
    num_pages = self.reserve_pages_insert(slot, seq_len)
    padded_caches = self.prefill_cache_padding(
        prefill_caches, seq_len, num_pages
    )
    # Drop the cache batch dimension -> [kv_heads, seq_len, dim].
    squeezed_caches = [
        (jnp.squeeze(k, axis=0), jnp.squeeze(v, axis=0))
        for k, v in padded_caches
    ]
    kv_heads, _, dim = squeezed_caches[0][0].shape
    # Split the sequence axis into pages
    # -> [kv_heads, num_pages, page_size, dim].
    paged_caches = [
        (
            jnp.reshape(k, (kv_heads, -1, self.page_size, dim)),
            jnp.reshape(v, (kv_heads, -1, self.page_size, dim)),
        )
        for k, v in squeezed_caches
    ]
    update_indexes = self.page_indices[slot, :num_pages]

    @functools.partial(jax.jit, donate_argnums=(0, 1), inline=True)
    def insert(cache, new_entry):
      # NOTE(review): squeeze(0) assumes kv_heads == 1; with more heads this
      # would raise at trace time. Confirm intended shapes with callers.
      new_entry = new_entry.squeeze(0)
      res = cache.at[:, update_indexes, :, :].set(new_entry)
      res = jax.lax.with_sharding_constraint(res, sharding)
      return res

    caches = [
        (insert(k, newk), insert(v, newv))
        for (k, v), (newk, newv) in zip(decode_caches, paged_caches)
    ]

    return caches

  # pylint: disable-next=all
  def get_page_token_indices(self, lens: jax.Array) -> jax.Array:
    """Computes per-slot page/token update metadata for decode.

    Args:
      lens: [batch_size, 1] array of current sequence lengths.

    Returns:
      A [4, num_active_slots] int array stacking, per active slot: the page
      index to update, the token position offset into a flattened per-slot
      page buffer, the slot id, and the new (incremented) sequence length.
      Slots whose current page is unassigned (-1) are skipped.
    """
    assert lens.shape == (
        self.batch_size,
        1,
    ), f"len shape: {lens.shape} not equals batch size: {self.batch_size, 1}"
    update_page_indices = []
    token_scale_indices = []
    batch_slots = []
    new_lens = []
    offset = 0
    for slot in range(self.batch_size):
      seq_len = lens[slot][0]
      num_pages = seq_len // self.page_size + 1
      token_pos = seq_len % self.page_size
      page_index = self.page_indices[slot, num_pages - 1]
      # A slot whose current page is unassigned is inactive; skip it
      # (offset only advances for active slots).
      if page_index < 0:
        continue
      update_page_indices.append(page_index)
      token_scale_indices.append(offset + token_pos)
      batch_slots.append(slot)
      new_lens.append(seq_len + 1)
      offset += self.page_size
    return jnp.stack(
        (
            jnp.asarray(update_page_indices),
            jnp.asarray(token_scale_indices),
            jnp.asarray(batch_slots),
            jnp.asarray(new_lens),
        )
    )

  # pylint: disable-next=all
  def fill_new_pages(self, lens: jax.Array):
    """Reserves a fresh decode page for every slot whose last page is full."""
    for slot in range(self.batch_size):
      self.reserve_pages_decode(slot, lens[slot])

  # pylint: disable-next=all
  def pad_sequences(self, array, pad_width=10):
    """Zero-pads `array` along the sequence axis (axis 2) by `pad_width`."""
    padding_config = [
        (0, 0),
        (0, 0),
        (0, pad_width),
        (0, 0),
    ]  # Pad only seq_len and dim
    return jnp.pad(array, padding_config, mode="constant")

  # pylint: disable-next=all
  def free_pages_resource(self, slot):
    """Returns all pages of `slot` to the free list and clears its indices."""
    for i in range(self.max_pages_per_sequence):
      index = self.page_indices[slot, i]
      # Pages are assigned contiguously from position 0, so the first -1
      # marks the end of this slot's pages.
      if index < 0:
        break
      self.unused_pages.put(index, block=False)

    # -1 broadcasts across the whole row, marking every position unassigned.
    self.page_indices = self.page_indices.at[slot, :].set(-1)
    return None
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import unittest | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
import torch | ||
|
||
from jetstream_pt.third_party.llama import model_args | ||
from jetstream_pt import environment | ||
from jetstream_pt.page_attention_manager import PageAttentionManager | ||
from jetstream_pt.cache_manager import PageKVCacheGenerate, KVCachePrefill | ||
from jetstream_pt import torchjax | ||
from absl.testing import parameterized | ||
|
||
|
||
class PageAttentnioTest(parameterized.TestCase):
  """Tests PageAttentionManager together with PageKVCacheGenerate."""

  # NOTE(review): class name typo ("Attentnio") kept so any external
  # references keep working; consider renaming to PageAttentionTest.

  def _make_env(self, bf16_enable=True):
    """Builds a small CPU-only JetEngine environment for cache tests."""
    torch_dtype = torch.bfloat16 if bf16_enable else torch.float32
    torch.set_default_dtype(torch_dtype)
    jax.config.update("jax_dynamic_shapes", False)
    jax.config.update("jax_traceback_filtering", "off")
    jax.config.update("jax_platform_name", "cpu")
    config = model_args.get_model_args("tiny", 128, 1, True)
    environment_data = environment.JetEngineEnvironmentData()
    environment_data.max_input_sequence_length = 128
    environment_data.cache_sequence_length = 128
    environment_data.bf16_enable = bf16_enable
    environment_data.model_type = "llama-2-tiny"
    environment_data.batch_size = 3
    environment_data.num_layers = config.n_layers
    environment_data.cache_shape = (
        1,
        config.n_kv_heads,
        environment_data.cache_sequence_length,
        config.dim // config.n_heads,
    )
    env = environment.JetEngineEnvironment(environment_data)
    env.apply_sharding = lambda *args, **kwargs: None  # don't shard on cpu
    return env, config

  def test_page_attention_update(self):
    """Prefill-inserts three slots, fills new pages, then decode-updates."""
    jax.config.update("jax_platform_name", "cpu")
    print(f"---------> {jax.devices()}")

    env, _ = self._make_env()

    pam = PageAttentionManager(
        batch_size=5, total_num_pages=20, page_size=4, max_pages_per_sequence=4
    )
    shape = (1, 20, 4, 2)
    decode_caches = []
    decode_caches.append(
        PageKVCacheGenerate.empty(
            shape=shape, device=None, bf16_enable=True, env=env
        )
    )
    decode_caches = [c.state() for c in decode_caches]

    self.cache_sharding = env.cache_sharding

    def _insert_prefill(seq_len, dim, slot):
      prefill_cache = KVCachePrefill()
      k, v = jnp.arange(seq_len * dim), jnp.arange(seq_len * dim)
      # Fix: the original reshaped k on both sides, silently discarding v;
      # it passed only because k and v are built identically.
      k, v = jnp.reshape(k, (1, 1, seq_len, dim)), jnp.reshape(
          v, (1, 1, seq_len, dim)
      )
      prefill_cache.update(k, v, 0)
      prefill_caches = [prefill_cache.state()]

      return pam.insert_prefill_cache(
          prefill_caches, decode_caches, slot, seq_len, env.cache_sharding
      )

    decode_caches = _insert_prefill(3, 2, 0)
    decode_caches = _insert_prefill(8, 2, 1)
    decode_caches = _insert_prefill(13, 2, 3)

    lens = jnp.asarray([3, 8, 0, 13, 0]).reshape(5, 1)
    pam.fill_new_pages(lens)
    page_token_indices = pam.get_page_token_indices(lens)
    page_token_indices = torchjax.to_torch(page_token_indices)

    caches_obj = [
        PageKVCacheGenerate(
            k, v, page_token_indices, self.cache_sharding, env=env
        )
        for k, v in torchjax.to_torch(decode_caches)
    ]
    xk, xv = jnp.arange(-1, -11, -1).reshape(5, 1, 1, 2), jnp.arange(
        -1, -11, -1
    ).reshape(5, 1, 1, 2)
    xk = torchjax.to_torch(xk)
    xv = torchjax.to_torch(xv)
    decode_caches = caches_obj[0].update(xk, xv)
    # Slot 0 (page 0): prefill tokens 0..5 plus the new token (-1, -2).
    expected = jnp.asarray([[0, 1], [2, 3], [4, 5], [-1, -2]])
    self.assertTrue(jnp.array_equal(decode_caches[0][0][0], expected))
    # Slot 1's freshly reserved page holds only its new token (-3, -4).
    expected = jnp.asarray([[-3, -4], [0, 0], [0, 0], [0, 0]])
    self.assertTrue(jnp.array_equal(decode_caches[0][0][7], expected))
    # Slot 3's last prefill page gets its new token (-7, -8) at position 1.
    expected = jnp.asarray([[24, 25], [-7, -8], [0, 0], [0, 0]])
    self.assertTrue(jnp.array_equal(decode_caches[0][0][6], expected))


if __name__ == "__main__":
  unittest.main()
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use
deque
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Current jet stream implementation's detokenize_threads and _generate_threads are different thread, both of them need to access this queue. So the queue should be thread safe, but deque is not thread safe.