diff --git a/jetstream_pt/cache_manager.py b/jetstream_pt/cache_manager.py
index 76f4412..9a44a47 100644
--- a/jetstream_pt/cache_manager.py
+++ b/jetstream_pt/cache_manager.py
@@ -654,3 +654,61 @@ def finalize(self):
         self.new_v_scaler,
         self.input_pos,
     )
+
+
+class PageKVCacheGenerate:
+  """Page attention KV cache generator without quantization."""
+
+  def __init__(
+      self,
+      cache_k: torch.Tensor,  # previous cache
+      cache_v: torch.Tensor,  # previous cache
+      page_token_indices: torch.Tensor,  # page and token indices for the cache
+      sharding,
+      env=None,
+  ):
+    super().__init__()
+    self.cache_k = cache_k
+    self.cache_v = cache_v
+    self.page_token_indices = page_token_indices
+    self.sharding = sharding
+    self.env = env
+
+  def update(self, key, value):
+    """Update kv cache"""
+    keyj, valuej, page_token_indicesj = torchjax.from_torch(
+        (key, value, self.page_token_indices)
+    )
+
+    def _update(cache, x):
+      x = x.squeeze(2).transpose((1, 0, 2))
+      x = x[:, page_token_indicesj[2], :]
+      head, _, page_size, dim = cache.shape
+      selected_cache = cache[:, page_token_indicesj[0], :, :]
+      selected_cache = selected_cache.reshape((head, -1, dim))
+
+      selected_cache = selected_cache.at[:, page_token_indicesj[1], :].set(x)
+      selected_cache = selected_cache.reshape((head, -1, page_size, dim))
+
+      cache = cache.at[:, page_token_indicesj[0], :, :].set(selected_cache)
+      return cache
+
+    # pylint: disable-next=all
+    self.cache_k._elem = _update(self.cache_k._elem, keyj)
+    # pylint: disable-next=all
+    self.cache_v._elem = _update(self.cache_v._elem, valuej)
+    return self.cache_k, self.cache_v
+
+  def state(self):
+    """Get kv cache state"""
+    # pylint: disable-next=all
+    return self.cache_k.jax(), self.cache_v.jax()
+
+  @classmethod
+  def empty(cls, shape, device, bf16_enable, env):
+    """Create empty kv caches"""
+    default_dtype = jnp.bfloat16 if bf16_enable else jnp.float32
+    k = jnp.zeros(shape, device=device, dtype=default_dtype)
+    v = jnp.zeros(shape, device=device, dtype=default_dtype)
+    k, v = torchjax.to_torch((k, v))
+    return cls(k, v, None, device, env=env)
diff --git a/jetstream_pt/page_attention_manager.py b/jetstream_pt/page_attention_manager.py
new file mode 100644
index 0000000..ab2d2c9
--- /dev/null
+++ b/jetstream_pt/page_attention_manager.py
@@ -0,0 +1,192 @@
+import queue
+import functools
+from typing import List, Tuple
+
+import jax
+import jax.numpy as jnp
+import jax.sharding as jsharding
+
+
+class PageAttentionManager:
+  """Manages page blocks.
+
+  This manager maintains a main list of free page blocks and supports:
+  1. Reserving pages for prefill insert and decode.
+  2. Freeing a slot's pages after decode; their indices return to the free list.
+  3. Getting page-index metadata for all slots.
+  4. Transforming and inserting prefill caches into decode caches.
+  """
+
+  def __init__(
+      self,
+      batch_size: int,
+      total_num_pages: int,
+      page_size: int,
+      max_pages_per_sequence: int,
+  ):
+    self.unused_pages = queue.Queue()
+    self.batch_size = batch_size
+    self.page_indices = jnp.full(
+        (batch_size, max_pages_per_sequence), -1, dtype=jnp.int32
+    )
+    self.lengths = jnp.zeros(batch_size, dtype=jnp.int32)
+    self.page_size = page_size
+    self.max_pages_per_sequence = max_pages_per_sequence
+    for i in range(total_num_pages):
+      self.unused_pages.put(i, block=False)
+
+  # pylint: disable-next=all
+  def reserve_pages_insert(self, slot: int, seq_len: int) -> int:
+    self.lengths = self.lengths.at[slot].set(seq_len)
+    num_pages = seq_len // self.page_size
+    if seq_len % self.page_size != 0:
+      num_pages = num_pages + 1
+
+    indices = [self.unused_pages.get(block=False) for _ in range(num_pages)]
+    self.page_indices = self.page_indices.at[slot, :num_pages].set(indices)
+    return num_pages
+
+  # pylint: disable-next=all
+  def reserve_pages_decode(self, slot: int, seq_len: int):
+    if seq_len > 0 and seq_len % self.page_size == 0:
+      index = self.unused_pages.get(block=False)
+      num_pages = seq_len // self.page_size
+      self.page_indices = self.page_indices.at[slot, num_pages].set(index)
+
+  # pylint: disable-next=all
+  def prefill_cache_padding(
+      self,
+      caches: List[Tuple[jax.Array, jax.Array]],
+      seq_len: int,
+      num_pages: int,
+  ) -> List[Tuple[jax.Array, jax.Array]]:
+
+    pad_width = num_pages * self.page_size - seq_len
+    if pad_width == 0:
+      return caches
+
+    caches = [
+        (self.pad_sequences(k, pad_width), self.pad_sequences(v, pad_width))
+        for k, v in caches
+    ]
+    return caches
+
+  def insert_prefill_cache(
+      self,
+      prefill_caches: List[Tuple[jax.Array, jax.Array]],
+      decode_caches: List[Tuple[jax.Array, jax.Array]],
+      slot: int,
+      seq_len: int,
+      sharding: jsharding.Sharding,
+  ) -> List[Tuple[jax.Array, jax.Array]]:
+    """Insert prefill caches into the decode caches at the given slot.
+
+    Args:
+      prefill_caches: List of (K, V) tuples. Each K, V is a
+        [batch_size, num_heads, seq_len, head_dim] jax.Array.
+      decode_caches: List of (K, V) tuples. Each K, V is a
+        [num_heads, total_num_pages, page_size, head_dim] jax.Array.
+      slot: Batch slot in decode.
+      seq_len: Prefill token sequence length.
+      sharding: Decode cache sharding.
+
+
+    Returns:
+      Decode caches. List of (K, V) tuples. Each K, V is a
+      [num_heads, total_num_pages, page_size, head_dim] jax.Array.
+    """
+
+    num_pages = self.reserve_pages_insert(slot, seq_len)
+    padded_caches = self.prefill_cache_padding(
+        prefill_caches, seq_len, num_pages
+    )
+    # Reduce the cache batch dimension
+    # [kv_heads, seq_len, dim]
+    squeezed_caches = [
+        (jnp.squeeze(k, axis=0), jnp.squeeze(v, axis=0))
+        for k, v in padded_caches
+    ]
+    kv_heads, _, dim = squeezed_caches[0][0].shape
+    # [kv_heads, num_pages, page_size, dim]
+    paged_caches = [
+        (
+            jnp.reshape(k, (kv_heads, -1, self.page_size, dim)),
+            jnp.reshape(v, (kv_heads, -1, self.page_size, dim)),
+        )
+        for k, v in squeezed_caches
+    ]
+    update_indexes = self.page_indices[slot, :num_pages]
+
+    @functools.partial(jax.jit, donate_argnums=(0, 1), inline=True)
+    def insert(cache, new_entry):
+      new_entry = new_entry.squeeze(0)
+      res = cache.at[:, update_indexes, :, :].set(new_entry)
+      res = jax.lax.with_sharding_constraint(res, sharding)
+      return res
+
+    caches = [
+        (insert(k, newk), insert(v, newv))
+        for (k, v), (newk, newv) in zip(decode_caches, paged_caches)
+    ]
+
+    return caches
+
+  # pylint: disable-next=all
+  def get_page_token_indices(self, lens: jax.Array) -> jax.Array:
+
+    assert lens.shape == (
+        self.batch_size,
+        1,
+    ), f"lens shape {lens.shape} does not match ({self.batch_size}, 1)"
+    update_page_indices = []
+    token_scale_indices = []
+    batch_slots = []
+    new_lens = []
+    offset = 0
+    for slot in range(self.batch_size):
+      seq_len = lens[slot][0]
+      num_pages = seq_len // self.page_size + 1
+      token_pos = seq_len % self.page_size
+      page_index = self.page_indices[slot, num_pages - 1]
+      if page_index < 0:
+        continue
+      update_page_indices.append(page_index)
+      token_scale_indices.append(offset + token_pos)
+      batch_slots.append(slot)
+      new_lens.append(seq_len + 1)
+      offset += self.page_size
+    return jnp.stack(
+        (
+            jnp.asarray(update_page_indices),
+            jnp.asarray(token_scale_indices),
+            jnp.asarray(batch_slots),
+            jnp.asarray(new_lens),
+        )
+    )
+
+  # pylint: disable-next=all
+  def fill_new_pages(self, lens: jax.Array):
+    for slot in range(self.batch_size):
+      self.reserve_pages_decode(slot, lens[slot])
+
+  # pylint: disable-next=all
+  def pad_sequences(self, array, pad_width=10):
+    padding_config = [
+        (0, 0),
+        (0, 0),
+        (0, pad_width),
+        (0, 0),
+    ]  # Pad the seq_len dimension only
+    padded_array = jnp.pad(array, padding_config, mode="constant")
+    return padded_array
+
+  # pylint: disable-next=all
+  def free_pages_resource(self, slot):
+    for i in range(self.max_pages_per_sequence):
+      index = self.page_indices[slot, i]
+      if index < 0:
+        break
+      self.unused_pages.put(index, block=False)
+
+    self.page_indices = self.page_indices.at[slot, :].set(jnp.asarray([-1]))
+    return None
diff --git a/tests/test_kv_cache_manager.py b/tests/test_kv_cache_manager.py
new file mode 100644
index 0000000..85df2a4
--- /dev/null
+++ b/tests/test_kv_cache_manager.py
@@ -0,0 +1,106 @@
+import unittest
+
+import jax
+import jax.numpy as jnp
+import torch
+
+from jetstream_pt.third_party.llama import model_args
+from jetstream_pt import environment
+from jetstream_pt.page_attention_manager import PageAttentionManager
+from jetstream_pt.cache_manager import PageKVCacheGenerate, KVCachePrefill
+from jetstream_pt import torchjax
+from absl.testing import parameterized
+
+
+class PageAttentionTest(parameterized.TestCase):
+
+  def _make_env(self, bf16_enable=True):
+    torch_dtype = torch.bfloat16 if bf16_enable else torch.float32
+    torch.set_default_dtype(torch_dtype)
+    jax.config.update("jax_dynamic_shapes", False)
+    jax.config.update("jax_traceback_filtering", "off")
jax.config.update("jax_platform_name", "cpu") + config = model_args.get_model_args("tiny", 128, 1, True) + environment_data = environment.JetEngineEnvironmentData() + environment_data.max_input_sequence_length = 128 + environment_data.max_input_sequence_length = 128 + environment_data.cache_sequence_length = 128 + environment_data.bf16_enable = bf16_enable + environment_data.model_type = "llama-2-tiny" + environment_data.batch_size = 3 + environment_data.num_layers = config.n_layers + environment_data.cache_shape = ( + 1, + config.n_kv_heads, + environment_data.cache_sequence_length, + config.dim // config.n_heads, + ) + env = environment.JetEngineEnvironment(environment_data) + env.apply_sharding = lambda *args, **kwargs: None # don't shard on cpu + return env, config + + def test_page_attention_update(self): + jax.config.update("jax_platform_name", "cpu") + print(f"---------> {jax.devices()}") + + env, _ = self._make_env() + + pam = PageAttentionManager( + batch_size=5, total_num_pages=20, page_size=4, max_pages_per_sequence=4 + ) + shape = (1, 20, 4, 2) + decode_caches = [] + decode_caches.append( + PageKVCacheGenerate.empty( + shape=shape, device=None, bf16_enable=True, env=env + ) + ) + decode_caches = [c.state() for c in decode_caches] + + self.cache_sharding = env.cache_sharding + + def _insert_prefill(seq_len, dim, slot): + prefill_chache = KVCachePrefill() + k, v = jnp.arange(seq_len * dim), jnp.arange(seq_len * dim) + k, v = jnp.reshape(k, (1, 1, seq_len, dim)), jnp.reshape( + k, (1, 1, seq_len, dim) + ) + prefill_chache.update(k, v, 0) + prefill_caches = [prefill_chache] + prefill_caches = [c.state() for c in prefill_caches] + + return pam.insert_prefill_cache( + prefill_caches, decode_caches, slot, seq_len, env.cache_sharding + ) + + decode_caches = _insert_prefill(3, 2, 0) + decode_caches = _insert_prefill(8, 2, 1) + decode_caches = _insert_prefill(13, 2, 3) + + lens = jnp.asarray([3, 8, 0, 13, 0]).reshape(5, 1) + pam.fill_new_pages(lens) + page_token_indices = pam.get_page_token_indices(lens) + page_token_indices = torchjax.to_torch(page_token_indices) + + caches_obj = [ + PageKVCacheGenerate( + k, v, page_token_indices, self.cache_sharding, env=env + ) + for k, v in torchjax.to_torch(decode_caches) + ] + xk, xv = jnp.arange(-1, -11, -1).reshape(5, 1, 1, 2), jnp.arange( + -1, -11, -1 + ).reshape(5, 1, 1, 2) + xk = torchjax.to_torch(xk) + xv = torchjax.to_torch(xv) + decode_caches = caches_obj[0].update(xk, xv) + expected = jnp.asarray([[0, 1], [2, 3], [4, 5], [-1, -2]]) + self.assertTrue(jnp.array_equal(decode_caches[0][0][0], expected)) + expected = jnp.asarray([[-3, -4], [0, 0], [0, 0], [0, 0]]) + self.assertTrue(jnp.array_equal(decode_caches[0][0][7], expected)) + expected = jnp.asarray([[24, 25], [-7, -8], [0, 0], [0, 0]]) + self.assertTrue(jnp.array_equal(decode_caches[0][0][6], expected)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_page_attention.py b/tests/test_page_attention.py new file mode 100644 index 0000000..4880fc8 --- /dev/null +++ b/tests/test_page_attention.py @@ -0,0 +1,161 @@ +import unittest + +import jax +import jax.numpy as jnp +import torch + +from jetstream_pt.third_party.llama import model_args +from jetstream_pt import environment +from jetstream_pt.page_attention_manager import PageAttentionManager +from jetstream_pt.cache_manager import PageKVCacheGenerate, KVCachePrefill +from absl.testing import parameterized + + +class PageAttentionTest(parameterized.TestCase): + + def _make_env(self, bf16_enable=True): + 
+    torch_dtype = torch.bfloat16 if bf16_enable else torch.float32
+    torch.set_default_dtype(torch_dtype)
+    jax.config.update("jax_dynamic_shapes", False)
+    jax.config.update("jax_traceback_filtering", "off")
+    jax.config.update("jax_platform_name", "cpu")
+    config = model_args.get_model_args("tiny", 128, 1, True)
+    environment_data = environment.JetEngineEnvironmentData()
+    environment_data.max_input_sequence_length = 128
+    environment_data.max_input_sequence_length = 128
+    environment_data.cache_sequence_length = 128
+    environment_data.bf16_enable = bf16_enable
+    environment_data.model_type = "llama-2-tiny"
+    environment_data.batch_size = 3
+    environment_data.num_layers = config.n_layers
+    environment_data.cache_shape = (
+        1,
+        config.n_kv_heads,
+        environment_data.cache_sequence_length,
+        config.dim // config.n_heads,
+    )
+    env = environment.JetEngineEnvironment(environment_data)
+    env.apply_sharding = lambda *args, **kwargs: None  # don't shard on cpu
+    return env, config
+
+  def test_prefill_insert(self):
+
+    env, _ = self._make_env()
+
+    pam = PageAttentionManager(
+        batch_size=3, total_num_pages=20, page_size=4, max_pages_per_sequence=4
+    )
+    shape = (1, 6, 4, 2)
+    decode_caches = []
+    decode_caches.append(
+        PageKVCacheGenerate.empty(
+            shape=shape, device=None, bf16_enable=True, env=env
+        )
+    )
+    decode_caches = [c.state() for c in decode_caches]
+
+    prefill_cache = KVCachePrefill()
+    k, v = jnp.arange(6), jnp.arange(6)
+    k, v = jnp.reshape(k, (1, 1, 3, 2)), jnp.reshape(v, (1, 1, 3, 2))
+    prefill_cache.update(k, v, 0)
+    prefill_caches = [prefill_cache]
+    prefill_caches = [c.state() for c in prefill_caches]
+
+    pam.insert_prefill_cache(
+        prefill_caches, decode_caches, 1, 3, env.x_sharding
+    )
+
+  def test_prefill_insert_multiple_pages(self):
+
+    jax.config.update("jax_platform_name", "cpu")
+    print(f"---------> {jax.devices()}")
+
+    env, _ = self._make_env()
+
+    pam = PageAttentionManager(
+        batch_size=3, total_num_pages=20, page_size=4, max_pages_per_sequence=4
+    )
+    shape = (1, 20, 4, 2)
+    decode_caches = []
+    decode_caches.append(
+        PageKVCacheGenerate.empty(
+            shape=shape, device=None, bf16_enable=True, env=env
+        )
+    )
+    decode_caches = [c.state() for c in decode_caches]
+
+    self.cache_sharding = env.cache_sharding
+
+    prefill_cache = KVCachePrefill()
+    k, v = jnp.arange(12), jnp.arange(12)
+    k, v = jnp.reshape(k, (1, 1, 6, 2)), jnp.reshape(v, (1, 1, 6, 2))
+    prefill_cache.update(k, v, 0)
+    prefill_caches = [prefill_cache]
+    prefill_caches = [c.state() for c in prefill_caches]
+
+    decode_caches = pam.insert_prefill_cache(
+        prefill_caches, decode_caches, 1, 6, env.cache_sharding
+    )
+    self.assertEqual(len(decode_caches), 1)
+    expected = jnp.arange(16).at[12:16].set([0, 0, 0, 0]).reshape(1, 2, 4, 2)
+
+    updated_k = jax.lax.slice_in_dim(decode_caches[0][0], 0, 2, axis=1)
+    self.assertTrue(jnp.array_equal(updated_k, expected))
+    noupdated_k = jax.lax.slice_in_dim(decode_caches[0][0], 2, 20, axis=1)
+    self.assertTrue(jnp.array_equal(noupdated_k, jnp.zeros_like(noupdated_k)))
+
+  def test_reserve_pages_decode(self):
+
+    env, _ = self._make_env()
+
+    pam = PageAttentionManager(
+        batch_size=3, total_num_pages=20, page_size=4, max_pages_per_sequence=4
+    )
+    slot = 1
+    seq_len = 8
+    pam.reserve_pages_insert(slot, seq_len)
+    expected_slot_page_indices = jnp.asarray([0, 1])
+    slot_page_indices = pam.page_indices[slot][0:2]
+    self.assertTrue(
+        jnp.array_equal(slot_page_indices, expected_slot_page_indices)
+    )
+
+    lens = jnp.asarray([0, seq_len, 0])
+    pam.fill_new_pages(lens)
+    expected_slot_page_indices = jnp.asarray([0, 1, 2])
+    slot_page_indices = pam.page_indices[slot][0:3]
+    self.assertTrue(
+        jnp.array_equal(slot_page_indices, expected_slot_page_indices)
+    )
+
+    expected_slot0_page_indices = jnp.asarray([-1, -1, -1, -1])
+    slot0_page_indices = pam.page_indices[0][0:4]
+    self.assertTrue(
+        jnp.array_equal(slot0_page_indices, expected_slot0_page_indices)
+    )
+
+  def test_get_page_token_indices(self):
+    env, _ = self._make_env()
+
+    pam = PageAttentionManager(
+        batch_size=5, total_num_pages=20, page_size=4, max_pages_per_sequence=4
+    )
+    pam.reserve_pages_insert(1, 8)
+    pam.reserve_pages_insert(3, 13)
+    pam.reserve_pages_insert(0, 3)
+
+    lens = jnp.asarray([3, 8, 0, 13, 0]).reshape(5, 1)
+    pam.fill_new_pages(lens)
+
+    page_token_indices = pam.get_page_token_indices(lens)
+
+    expected_page_indices = jnp.asarray([6, 7, 5])
+    expected_token_indices = jnp.asarray([3, 4, 9])
+    self.assertTrue(
+        jnp.array_equal(page_token_indices[0], expected_page_indices)
+    )
+    self.assertTrue(
+        jnp.array_equal(page_token_indices[1], expected_token_indices)
+    )
+
+
+if __name__ == "__main__":
+  unittest.main()
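Reviewer note: the sketch below shows how the pieces introduced in this diff are meant to fit together during serving, distilled from `test_page_attention_update` and `_make_env()` in the tests above. It is a minimal sketch, not part of the change; the environment setup is copied from the tests, and the lens values, shapes, page counts, and dummy key/value tensors are illustrative assumptions only.

```python
import jax
import jax.numpy as jnp
import torch

from jetstream_pt import environment, torchjax
from jetstream_pt.cache_manager import KVCachePrefill, PageKVCacheGenerate
from jetstream_pt.page_attention_manager import PageAttentionManager
from jetstream_pt.third_party.llama import model_args

# Environment setup copied from _make_env() in the tests (CPU, tiny llama config).
torch.set_default_dtype(torch.bfloat16)
jax.config.update("jax_platform_name", "cpu")
config = model_args.get_model_args("tiny", 128, 1, True)
env_data = environment.JetEngineEnvironmentData()
env_data.max_input_sequence_length = 128
env_data.cache_sequence_length = 128
env_data.bf16_enable = True
env_data.model_type = "llama-2-tiny"
env_data.batch_size = 3
env_data.num_layers = config.n_layers
env_data.cache_shape = (
    1,
    config.n_kv_heads,
    env_data.cache_sequence_length,
    config.dim // config.n_heads,
)
env = environment.JetEngineEnvironment(env_data)
env.apply_sharding = lambda *args, **kwargs: None  # don't shard on cpu

# One layer of paged decode cache: [num_heads, total_num_pages, page_size, head_dim].
pam = PageAttentionManager(
    batch_size=5, total_num_pages=20, page_size=4, max_pages_per_sequence=4
)
decode_caches = [
    PageKVCacheGenerate.empty(
        shape=(1, 20, 4, 2), device=None, bf16_enable=True, env=env
    ).state()
]

# 1. Prefill: build a [1, num_heads, seq_len, head_dim] cache, then insert it into slot 0.
prefill_cache = KVCachePrefill()
k = v = jnp.arange(6).reshape(1, 1, 3, 2)
prefill_cache.update(k, v, 0)
decode_caches = pam.insert_prefill_cache(
    [prefill_cache.state()], decode_caches, 0, 3, env.cache_sharding
)

# 2. Decode step: reserve a fresh page for any slot that just filled its last page,
#    then gather (page index, token offset, slot, new length) for the active slots.
lens = jnp.asarray([3, 0, 0, 0, 0]).reshape(5, 1)
pam.fill_new_pages(lens)
page_token_indices = torchjax.to_torch(pam.get_page_token_indices(lens))

# 3. Write this step's key/value through the paged generate cache.
cache_k, cache_v = torchjax.to_torch(decode_caches[0])
cache = PageKVCacheGenerate(
    cache_k, cache_v, page_token_indices, env.cache_sharding, env=env
)
xk = xv = torchjax.to_torch(jnp.ones((5, 1, 1, 2)))
cache.update(xk, xv)

# 4. When slot 0 finishes decoding, its pages go back to the free list for reuse.
pam.free_pages_resource(0)
```

The design keeps all page bookkeeping (the free queue, `page_indices`, `lengths`) on the host in `PageAttentionManager`, while `PageKVCacheGenerate` only performs the gather/scatter into the paged arrays, so the manager can hand out and reclaim pages per slot without reshaping the decode caches.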