Skip to content

Switch from JAX to NumPy to improve attention manager performance #184

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions jetstream_pt/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,9 +653,10 @@ def insert_page_attention_with_reservation(
decode_state: DecodeState,
slot: int,
) -> DecodeState:
num_pages, update_indexes = (
num_pages, np_update_indexes = (
self.page_attention_manager.reserve_pages_insert(slot, prefix.seq_len)
)
update_indexes = jnp.array(np_update_indexes)
_, kv_heads, _, dim = prefix.caches[0][0].shape
tep_kv = jnp.zeros(
(
Expand Down Expand Up @@ -735,10 +736,12 @@ def generate(
def generate_page_attention(
self, params: Any, decode_state: DecodeState
) -> tuple[DecodeState, engine_api.ResultTokens]:
self.page_attention_manager.fill_new_pages(decode_state.input_pos)
page_token_indices = self.page_attention_manager.get_page_token_indices(
decode_state.input_pos
np_pos = np.asarray(decode_state.input_pos.block_until_ready())
self.page_attention_manager.fill_new_pages(np_pos)
np_page_token_indices = self.page_attention_manager.get_page_token_indices(
np_pos
)
page_token_indices = jnp.asarray(np_page_token_indices)
new_decode_state, result_tokens = self.generate_jit(
params, decode_state, page_token_indices=page_token_indices
)
Expand Down
31 changes: 15 additions & 16 deletions jetstream_pt/page_attention_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import jax
import jax.numpy as jnp
import jax.sharding as jsharding
import numpy as np


class PageAttentionManager:
Expand All @@ -26,41 +27,39 @@ def __init__(
):
self.unused_pages = queue.Queue()
self.batch_size = batch_size
self.page_indices = jnp.full(
self.page_indices = np.full(
(batch_size, max_pages_per_sequence),
paged_attention_total_num_pages - 1,
dtype=jnp.int32,
dtype=np.int32,
)
self.lengths = jnp.zeros(batch_size, dtype=jnp.int32)
self.lengths = np.zeros(batch_size, dtype=np.int32)
self.paged_attention_page_size = paged_attention_page_size
self.max_pages_per_sequence = max_pages_per_sequence
for i in range(paged_attention_total_num_pages):
self.unused_pages.put(i, block=False)

# pylint: disable-next=all
def reserve_pages_insert(
self, slot: int, seq_len: int
) -> Tuple[int, jax.Array]:
self.lengths = self.lengths.at[slot].set(seq_len)
def reserve_pages_insert(self, slot: int, seq_len: int):
self.lengths[slot] = seq_len
num_pages = (
seq_len // self.paged_attention_page_size
if seq_len % self.paged_attention_page_size == 0
else seq_len // self.paged_attention_page_size + 1
)

indices = [self.unused_pages.get(block=False) for _ in range(num_pages)]
self.page_indices = self.page_indices.at[slot, :num_pages].set(indices)
self.page_indices[slot, :num_pages] = indices
return num_pages, self.page_indices[slot, :num_pages]

# pylint: disable-next=all
def reserve_pages_decode(self, slot: int, seq_len: int):
if seq_len > 0 and seq_len % self.paged_attention_page_size == 0:
index = self.unused_pages.get(block=False)
num_pages = seq_len // self.paged_attention_page_size
self.page_indices = self.page_indices.at[slot, num_pages].set(index)
self.page_indices[slot, num_pages] = index

# pylint: disable-next=all
def fill_new_pages(self, lens: jax.Array):
def fill_new_pages(self, lens):
for slot in range(self.batch_size):
self.reserve_pages_decode(slot, lens[slot])

Expand Down Expand Up @@ -143,7 +142,7 @@ def insert(cache, new_entry):
return caches

# pylint: disable-next=all
def get_page_token_indices(self, lens: jax.Array) -> jax.Array:
def get_page_token_indices(self, lens):
# assert lens.shape == (
# self.batch_size,
# 1,
Expand All @@ -165,11 +164,11 @@ def get_page_token_indices(self, lens: jax.Array) -> jax.Array:
token_scale_indices.append(offset + token_pos)
batch_slots.append(slot)
offset += self.paged_attention_page_size
self.lengths = jnp.where(lens == 0, 0, lens + 1)
update_page_indices = jnp.asarray(update_page_indices)
token_scale_indices = jnp.asarray(token_scale_indices)
batch_slots = jnp.asarray(batch_slots)
return jnp.stack(
self.lengths = np.where(lens == 0, 0, lens + 1)
update_page_indices = np.asarray(update_page_indices)
token_scale_indices = np.asarray(token_scale_indices)
batch_slots = np.asarray(batch_slots)
return np.stack(
(
update_page_indices,
token_scale_indices,
Expand Down
5 changes: 3 additions & 2 deletions tests/test_kv_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,10 @@ def _insert_prefill(seq_len, dim, slot):
decode_caches = _insert_prefill(8, 2, 1)
decode_caches = _insert_prefill(13, 2, 3)

lens = jnp.asarray([3, 8, 0, 13, 0])
lens = np.asarray([3, 8, 0, 13, 0])
pam.fill_new_pages(lens)
page_token_indices = pam.get_page_token_indices(lens)
np_page_token_indices = pam.get_page_token_indices(lens)
page_token_indices = jnp.asarray(np_page_token_indices)
page_token_indices = torchjax.to_torch(page_token_indices)

caches_obj = [
Expand Down
24 changes: 12 additions & 12 deletions tests/test_page_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,23 +157,23 @@ def test_reserve_pages_decode(self):
slot = 1
seq_len = 8
pam.reserve_pages_insert(slot, seq_len)
expected_slot_page_indices = jnp.asarray([0, 1])
expected_slot_page_indices = np.asarray([0, 1])
slot_page_indices = pam.page_indices[slot][0:2]
self.assertTrue(
jnp.array_equal(slot_page_indices, expected_slot_page_indices)
np.array_equal(slot_page_indices, expected_slot_page_indices)
)

lens = jnp.asarray([0, seq_len, 0])
lens = np.asarray([0, seq_len, 0])
pam.fill_new_pages(lens)
expected_slot_page_indices = jnp.asarray([0, 1, 2, 19])
expected_slot_page_indices = np.asarray([0, 1, 2, 19])
slot_page_indices = pam.page_indices[slot]
self.assertTrue(
jnp.array_equal(slot_page_indices, expected_slot_page_indices)
np.array_equal(slot_page_indices, expected_slot_page_indices)
)

expected_0_page_indices = jnp.asarray([19, 19, 19, 19])
expected_0_page_indices = np.asarray([19, 19, 19, 19])
zer0_page_indices = pam.page_indices[0][0:4]
self.assertTrue(jnp.array_equal(zer0_page_indices, expected_0_page_indices))
self.assertTrue(np.array_equal(zer0_page_indices, expected_0_page_indices))

def test_get_page_token_indices(self):
env, _ = self._make_env()
Expand All @@ -188,18 +188,18 @@ def test_get_page_token_indices(self):
pam.reserve_pages_insert(3, 13)
pam.reserve_pages_insert(0, 3)

lens = jnp.asarray([3, 8, 0, 13, 0])
lens = np.asarray([3, 8, 0, 13, 0])
pam.fill_new_pages(lens)

page_token_indices = pam.get_page_token_indices(lens)

expected_page_indices = jnp.asarray([6, 7, 5])
expected_token_indices = jnp.asarray([3, 4, 9])
expected_page_indices = np.asarray([6, 7, 5])
expected_token_indices = np.asarray([3, 4, 9])
self.assertTrue(
jnp.array_equal(page_token_indices[0], expected_page_indices)
np.array_equal(page_token_indices[0], expected_page_indices)
)
self.assertTrue(
jnp.array_equal(page_token_indices[1], expected_token_indices)
np.array_equal(page_token_indices[1], expected_token_indices)
)


Expand Down
Loading