Skip to content

Commit f1fa78c

Browse files
authored
Switch to NP from Jax to improve attention manager performance (#184)
Switch to np from jax to improve attention manager performance
1 parent 5b8823e commit f1fa78c

File tree

4 files changed

+37
-34
lines changed

4 files changed

+37
-34
lines changed

jetstream_pt/engine.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -653,9 +653,10 @@ def insert_page_attention_with_reservation(
653653
decode_state: DecodeState,
654654
slot: int,
655655
) -> DecodeState:
656-
num_pages, update_indexes = (
656+
num_pages, np_update_indexes = (
657657
self.page_attention_manager.reserve_pages_insert(slot, prefix.seq_len)
658658
)
659+
update_indexes = jnp.array(np_update_indexes)
659660
_, kv_heads, _, dim = prefix.caches[0][0].shape
660661
tep_kv = jnp.zeros(
661662
(
@@ -735,10 +736,12 @@ def generate(
735736
def generate_page_attention(
736737
self, params: Any, decode_state: DecodeState
737738
) -> tuple[DecodeState, engine_api.ResultTokens]:
738-
self.page_attention_manager.fill_new_pages(decode_state.input_pos)
739-
page_token_indices = self.page_attention_manager.get_page_token_indices(
740-
decode_state.input_pos
739+
np_pos = np.asarray(decode_state.input_pos.block_until_ready())
740+
self.page_attention_manager.fill_new_pages(np_pos)
741+
np_page_token_indices = self.page_attention_manager.get_page_token_indices(
742+
np_pos
741743
)
744+
page_token_indices = jnp.asarray(np_page_token_indices)
742745
new_decode_state, result_tokens = self.generate_jit(
743746
params, decode_state, page_token_indices=page_token_indices
744747
)

jetstream_pt/page_attention_manager.py

Lines changed: 15 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
import jax
66
import jax.numpy as jnp
77
import jax.sharding as jsharding
8+
import numpy as np
89

910

1011
class PageAttentionManager:
@@ -26,41 +27,39 @@ def __init__(
2627
):
2728
self.unused_pages = queue.Queue()
2829
self.batch_size = batch_size
29-
self.page_indices = jnp.full(
30+
self.page_indices = np.full(
3031
(batch_size, max_pages_per_sequence),
3132
paged_attention_total_num_pages - 1,
32-
dtype=jnp.int32,
33+
dtype=np.int32,
3334
)
34-
self.lengths = jnp.zeros(batch_size, dtype=jnp.int32)
35+
self.lengths = np.zeros(batch_size, dtype=np.int32)
3536
self.paged_attention_page_size = paged_attention_page_size
3637
self.max_pages_per_sequence = max_pages_per_sequence
3738
for i in range(paged_attention_total_num_pages):
3839
self.unused_pages.put(i, block=False)
3940

4041
# pylint: disable-next=all
41-
def reserve_pages_insert(
42-
self, slot: int, seq_len: int
43-
) -> Tuple[int, jax.Array]:
44-
self.lengths = self.lengths.at[slot].set(seq_len)
42+
def reserve_pages_insert(self, slot: int, seq_len: int):
43+
self.lengths[slot] = seq_len
4544
num_pages = (
4645
seq_len // self.paged_attention_page_size
4746
if seq_len % self.paged_attention_page_size == 0
4847
else seq_len // self.paged_attention_page_size + 1
4948
)
5049

5150
indices = [self.unused_pages.get(block=False) for _ in range(num_pages)]
52-
self.page_indices = self.page_indices.at[slot, :num_pages].set(indices)
51+
self.page_indices[slot, :num_pages] = indices
5352
return num_pages, self.page_indices[slot, :num_pages]
5453

5554
# pylint: disable-next=all
5655
def reserve_pages_decode(self, slot: int, seq_len: int):
5756
if seq_len > 0 and seq_len % self.paged_attention_page_size == 0:
5857
index = self.unused_pages.get(block=False)
5958
num_pages = seq_len // self.paged_attention_page_size
60-
self.page_indices = self.page_indices.at[slot, num_pages].set(index)
59+
self.page_indices[slot, num_pages] = index
6160

6261
# pylint: disable-next=all
63-
def fill_new_pages(self, lens: jax.Array):
62+
def fill_new_pages(self, lens):
6463
for slot in range(self.batch_size):
6564
self.reserve_pages_decode(slot, lens[slot])
6665

@@ -143,7 +142,7 @@ def insert(cache, new_entry):
143142
return caches
144143

145144
# pylint: disable-next=all
146-
def get_page_token_indices(self, lens: jax.Array) -> jax.Array:
145+
def get_page_token_indices(self, lens):
147146
# assert lens.shape == (
148147
# self.batch_size,
149148
# 1,
@@ -165,11 +164,11 @@ def get_page_token_indices(self, lens: jax.Array) -> jax.Array:
165164
token_scale_indices.append(offset + token_pos)
166165
batch_slots.append(slot)
167166
offset += self.paged_attention_page_size
168-
self.lengths = jnp.where(lens == 0, 0, lens + 1)
169-
update_page_indices = jnp.asarray(update_page_indices)
170-
token_scale_indices = jnp.asarray(token_scale_indices)
171-
batch_slots = jnp.asarray(batch_slots)
172-
return jnp.stack(
167+
self.lengths = np.where(lens == 0, 0, lens + 1)
168+
update_page_indices = np.asarray(update_page_indices)
169+
token_scale_indices = np.asarray(token_scale_indices)
170+
batch_slots = np.asarray(batch_slots)
171+
return np.stack(
173172
(
174173
update_page_indices,
175174
token_scale_indices,

tests/test_kv_cache_manager.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -93,9 +93,10 @@ def _insert_prefill(seq_len, dim, slot):
9393
decode_caches = _insert_prefill(8, 2, 1)
9494
decode_caches = _insert_prefill(13, 2, 3)
9595

96-
lens = jnp.asarray([3, 8, 0, 13, 0])
96+
lens = np.asarray([3, 8, 0, 13, 0])
9797
pam.fill_new_pages(lens)
98-
page_token_indices = pam.get_page_token_indices(lens)
98+
np_page_token_indices = pam.get_page_token_indices(lens)
99+
page_token_indices = jnp.asarray(np_page_token_indices)
99100
page_token_indices = torchjax.to_torch(page_token_indices)
100101

101102
caches_obj = [

tests/test_page_attention.py

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -157,23 +157,23 @@ def test_reserve_pages_decode(self):
157157
slot = 1
158158
seq_len = 8
159159
pam.reserve_pages_insert(slot, seq_len)
160-
expected_slot_page_indices = jnp.asarray([0, 1])
160+
expected_slot_page_indices = np.asarray([0, 1])
161161
slot_page_indices = pam.page_indices[slot][0:2]
162162
self.assertTrue(
163-
jnp.array_equal(slot_page_indices, expected_slot_page_indices)
163+
np.array_equal(slot_page_indices, expected_slot_page_indices)
164164
)
165165

166-
lens = jnp.asarray([0, seq_len, 0])
166+
lens = np.asarray([0, seq_len, 0])
167167
pam.fill_new_pages(lens)
168-
expected_slot_page_indices = jnp.asarray([0, 1, 2, 19])
168+
expected_slot_page_indices = np.asarray([0, 1, 2, 19])
169169
slot_page_indices = pam.page_indices[slot]
170170
self.assertTrue(
171-
jnp.array_equal(slot_page_indices, expected_slot_page_indices)
171+
np.array_equal(slot_page_indices, expected_slot_page_indices)
172172
)
173173

174-
expected_0_page_indices = jnp.asarray([19, 19, 19, 19])
174+
expected_0_page_indices = np.asarray([19, 19, 19, 19])
175175
zer0_page_indices = pam.page_indices[0][0:4]
176-
self.assertTrue(jnp.array_equal(zer0_page_indices, expected_0_page_indices))
176+
self.assertTrue(np.array_equal(zer0_page_indices, expected_0_page_indices))
177177

178178
def test_get_page_token_indices(self):
179179
env, _ = self._make_env()
@@ -188,18 +188,18 @@ def test_get_page_token_indices(self):
188188
pam.reserve_pages_insert(3, 13)
189189
pam.reserve_pages_insert(0, 3)
190190

191-
lens = jnp.asarray([3, 8, 0, 13, 0])
191+
lens = np.asarray([3, 8, 0, 13, 0])
192192
pam.fill_new_pages(lens)
193193

194194
page_token_indices = pam.get_page_token_indices(lens)
195195

196-
expected_page_indices = jnp.asarray([6, 7, 5])
197-
expected_token_indices = jnp.asarray([3, 4, 9])
196+
expected_page_indices = np.asarray([6, 7, 5])
197+
expected_token_indices = np.asarray([3, 4, 9])
198198
self.assertTrue(
199-
jnp.array_equal(page_token_indices[0], expected_page_indices)
199+
np.array_equal(page_token_indices[0], expected_page_indices)
200200
)
201201
self.assertTrue(
202-
jnp.array_equal(page_token_indices[1], expected_token_indices)
202+
np.array_equal(page_token_indices[1], expected_token_indices)
203203
)
204204

205205

0 commit comments

Comments
 (0)