From 2e13e19f7300bfbe34ac14e8718a1ddc0805194b Mon Sep 17 00:00:00 2001 From: FanhaiLu1 Date: Thu, 19 Sep 2024 17:21:26 +0000 Subject: [PATCH] Switch to np from jax to improve attention manager performance --- jetstream_pt/engine.py | 11 +++++---- jetstream_pt/page_attention_manager.py | 31 +++++++++++++------------- tests/test_kv_cache_manager.py | 5 +++-- tests/test_page_attention.py | 24 ++++++++++---------- 4 files changed, 37 insertions(+), 34 deletions(-) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index 2c7e38e..b7298e1 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -653,9 +653,10 @@ def insert_page_attention_with_reservation( decode_state: DecodeState, slot: int, ) -> DecodeState: - num_pages, update_indexes = ( + num_pages, np_update_indexes = ( self.page_attention_manager.reserve_pages_insert(slot, prefix.seq_len) ) + update_indexes = jnp.array(np_update_indexes) _, kv_heads, _, dim = prefix.caches[0][0].shape tep_kv = jnp.zeros( ( @@ -735,10 +736,12 @@ def generate( def generate_page_attention( self, params: Any, decode_state: DecodeState ) -> tuple[DecodeState, engine_api.ResultTokens]: - self.page_attention_manager.fill_new_pages(decode_state.input_pos) - page_token_indices = self.page_attention_manager.get_page_token_indices( - decode_state.input_pos + np_pos = np.asarray(decode_state.input_pos.block_until_ready()) + self.page_attention_manager.fill_new_pages(np_pos) + np_page_token_indices = self.page_attention_manager.get_page_token_indices( + np_pos ) + page_token_indices = jnp.asarray(np_page_token_indices) new_decode_state, result_tokens = self.generate_jit( params, decode_state, page_token_indices=page_token_indices ) diff --git a/jetstream_pt/page_attention_manager.py b/jetstream_pt/page_attention_manager.py index eb17765..a3fd0a5 100644 --- a/jetstream_pt/page_attention_manager.py +++ b/jetstream_pt/page_attention_manager.py @@ -5,6 +5,7 @@ import jax import jax.numpy as jnp import jax.sharding as jsharding 
+import numpy as np class PageAttentionManager: @@ -26,22 +27,20 @@ def __init__( ): self.unused_pages = queue.Queue() self.batch_size = batch_size - self.page_indices = jnp.full( + self.page_indices = np.full( (batch_size, max_pages_per_sequence), paged_attention_total_num_pages - 1, - dtype=jnp.int32, + dtype=np.int32, ) - self.lengths = jnp.zeros(batch_size, dtype=jnp.int32) + self.lengths = np.zeros(batch_size, dtype=np.int32) self.paged_attention_page_size = paged_attention_page_size self.max_pages_per_sequence = max_pages_per_sequence for i in range(paged_attention_total_num_pages): self.unused_pages.put(i, block=False) # pylint: disable-next=all - def reserve_pages_insert( - self, slot: int, seq_len: int - ) -> Tuple[int, jax.Array]: - self.lengths = self.lengths.at[slot].set(seq_len) + def reserve_pages_insert(self, slot: int, seq_len: int): + self.lengths[slot] = seq_len num_pages = ( seq_len // self.paged_attention_page_size if seq_len % self.paged_attention_page_size == 0 @@ -49,7 +48,7 @@ def reserve_pages_insert( ) indices = [self.unused_pages.get(block=False) for _ in range(num_pages)] - self.page_indices = self.page_indices.at[slot, :num_pages].set(indices) + self.page_indices[slot, :num_pages] = indices return num_pages, self.page_indices[slot, :num_pages] # pylint: disable-next=all @@ -57,10 +56,10 @@ def reserve_pages_decode(self, slot: int, seq_len: int): if seq_len > 0 and seq_len % self.paged_attention_page_size == 0: index = self.unused_pages.get(block=False) num_pages = seq_len // self.paged_attention_page_size - self.page_indices = self.page_indices.at[slot, num_pages].set(index) + self.page_indices[slot, num_pages] = index # pylint: disable-next=all - def fill_new_pages(self, lens: jax.Array): + def fill_new_pages(self, lens): for slot in range(self.batch_size): self.reserve_pages_decode(slot, lens[slot]) @@ -143,7 +142,7 @@ def insert(cache, new_entry): return caches # pylint: disable-next=all - def get_page_token_indices(self, lens: 
jax.Array) -> jax.Array: + def get_page_token_indices(self, lens): # assert lens.shape == ( # self.batch_size, # 1, @@ -165,11 +164,11 @@ def get_page_token_indices(self, lens: jax.Array) -> jax.Array: token_scale_indices.append(offset + token_pos) batch_slots.append(slot) offset += self.paged_attention_page_size - self.lengths = jnp.where(lens == 0, 0, lens + 1) - update_page_indices = jnp.asarray(update_page_indices) - token_scale_indices = jnp.asarray(token_scale_indices) - batch_slots = jnp.asarray(batch_slots) - return jnp.stack( + self.lengths = np.where(lens == 0, 0, lens + 1) + update_page_indices = np.asarray(update_page_indices) + token_scale_indices = np.asarray(token_scale_indices) + batch_slots = np.asarray(batch_slots) + return np.stack( ( update_page_indices, token_scale_indices, diff --git a/tests/test_kv_cache_manager.py b/tests/test_kv_cache_manager.py index 74c7bce..38e90dc 100644 --- a/tests/test_kv_cache_manager.py +++ b/tests/test_kv_cache_manager.py @@ -93,9 +93,10 @@ def _insert_prefill(seq_len, dim, slot): decode_caches = _insert_prefill(8, 2, 1) decode_caches = _insert_prefill(13, 2, 3) - lens = jnp.asarray([3, 8, 0, 13, 0]) + lens = np.asarray([3, 8, 0, 13, 0]) pam.fill_new_pages(lens) - page_token_indices = pam.get_page_token_indices(lens) + np_page_token_indices = pam.get_page_token_indices(lens) + page_token_indices = jnp.asarray(np_page_token_indices) page_token_indices = torchjax.to_torch(page_token_indices) caches_obj = [ diff --git a/tests/test_page_attention.py b/tests/test_page_attention.py index 294cb8d..d531bc4 100644 --- a/tests/test_page_attention.py +++ b/tests/test_page_attention.py @@ -157,23 +157,23 @@ def test_reserve_pages_decode(self): slot = 1 seq_len = 8 pam.reserve_pages_insert(slot, seq_len) - expected_slot_page_indices = jnp.asarray([0, 1]) + expected_slot_page_indices = np.asarray([0, 1]) slot_page_indices = pam.page_indices[slot][0:2] self.assertTrue( - jnp.array_equal(slot_page_indices, 
expected_slot_page_indices) + np.array_equal(slot_page_indices, expected_slot_page_indices) ) - lens = jnp.asarray([0, seq_len, 0]) + lens = np.asarray([0, seq_len, 0]) pam.fill_new_pages(lens) - expected_slot_page_indices = jnp.asarray([0, 1, 2, 19]) + expected_slot_page_indices = np.asarray([0, 1, 2, 19]) slot_page_indices = pam.page_indices[slot] self.assertTrue( - jnp.array_equal(slot_page_indices, expected_slot_page_indices) + np.array_equal(slot_page_indices, expected_slot_page_indices) ) - expected_0_page_indices = jnp.asarray([19, 19, 19, 19]) + expected_0_page_indices = np.asarray([19, 19, 19, 19]) zer0_page_indices = pam.page_indices[0][0:4] - self.assertTrue(jnp.array_equal(zer0_page_indices, expected_0_page_indices)) + self.assertTrue(np.array_equal(zer0_page_indices, expected_0_page_indices)) def test_get_page_token_indices(self): env, _ = self._make_env() @@ -188,18 +188,18 @@ def test_get_page_token_indices(self): pam.reserve_pages_insert(3, 13) pam.reserve_pages_insert(0, 3) - lens = jnp.asarray([3, 8, 0, 13, 0]) + lens = np.asarray([3, 8, 0, 13, 0]) pam.fill_new_pages(lens) page_token_indices = pam.get_page_token_indices(lens) - expected_page_indices = jnp.asarray([6, 7, 5]) - expected_token_indices = jnp.asarray([3, 4, 9]) + expected_page_indices = np.asarray([6, 7, 5]) + expected_token_indices = np.asarray([3, 4, 9]) self.assertTrue( - jnp.array_equal(page_token_indices[0], expected_page_indices) + np.array_equal(page_token_indices[0], expected_page_indices) ) self.assertTrue( - jnp.array_equal(page_token_indices[1], expected_token_indices) + np.array_equal(page_token_indices[1], expected_token_indices) )