Commit 33348d2

Support End To End PagedAttention in JetStream (#180)

* Support E2E PagedAttention in JetStream
* Change to the right JetStream version
* Fix page index shape
* Fix decode length shape
* Fix variable name
* Use __getattr__
* Fix env variable

Parent: ec4ac8f

11 files changed: +1006 −101 lines

jetstream_pt/attention_kernel.py (55 additions, 2 deletions)

@@ -1,18 +1,22 @@
+from collections.abc import Callable
 import functools
 import math
+from typing import Any
 
 import jax
 import jax.numpy as jnp
 from jax.experimental import pallas as pl
 from jax.experimental.pallas import tpu as pltpu
+from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention
 from jax.experimental.shard_map import shard_map
-
+import numpy as np
 import torch
 import torch.nn.functional as F
+from jetstream_pt import torchjax
 
-import numpy as np
 
 DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max)
+P = jax.sharding.PartitionSpec
 
 
 def ragged_flash_attention_kernel(
@@ -735,3 +739,52 @@ def __call__(
         k_scaler,
         v_scaler,
     )
+
+
+def shard_kv_heads(
+    paged_attention_impl: Callable[..., Any],
+    mesh: jax.sharding.Mesh,
+    kv_head_mesh_axis_name: str,
+):
+  """Shard map on kv head."""
+  in_specs = (
+      P(None, kv_head_mesh_axis_name, None),  # q
+      P(kv_head_mesh_axis_name, None, None, None),  # k
+      P(kv_head_mesh_axis_name, None, None, None),  # v
+      P(),  # lengths
+      P(),  # page_indices
+  )
+
+  out_specs = P(None, kv_head_mesh_axis_name, None)  # q
+
+  return jax.jit(
+      shard_map(
+          paged_attention_impl,
+          mesh=mesh,
+          in_specs=in_specs,
+          out_specs=out_specs,
+          check_rep=False,
+      )
+  )
+
+
+def call_paged_attention(env, xq, keys, values, seq_lens, page_indices):
+  """Paged attention kernel."""
+  xq, keys, values, seq_lens, page_indices = torchjax.from_torch(
+      (xq, keys, values, seq_lens, page_indices)
+  )
+  paged_attention_impl = functools.partial(
+      paged_attention,
+      pages_per_compute_block=env.block_size // env.paged_attention_page_size,
+      # mask_value=float("-inf")
+  )
+  sharded_paged_attention_impl = shard_kv_heads(
+      paged_attention_impl,
+      env.mesh,
+      kv_head_mesh_axis_name="x",
+  )
+  output = sharded_paged_attention_impl(
+      xq, keys, values, seq_lens, page_indices
+  )
+
+  return torchjax.to_torch(output)
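The two new helpers wrap the Pallas TPU paged-attention kernel in a `shard_map` that splits the query heads and the KV-head pages across one mesh axis while replicating the per-sequence lengths and page indices. The sketch below exercises the same sharding contract on CPU with a stand-in kernel; the shapes, the single-device mesh, and `fake_paged_attention` are illustrative assumptions, not part of this commit. On TPU, the `functools.partial(paged_attention, pages_per_compute_block=...)` built in `call_paged_attention` would take the stand-in's place.

```python
# A minimal sketch (not the engine's real wiring): exercise the sharding
# contract that shard_kv_heads sets up, using a stand-in kernel so it runs
# on CPU. On TPU the real Pallas `paged_attention` partial would be wrapped
# instead, as call_paged_attention does.
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map

P = jax.sharding.PartitionSpec

batch, heads, kv_heads, head_dim = 2, 8, 8, 128
total_pages, page_size, pages_per_seq = 16, 64, 4

mesh = jax.sharding.Mesh(jax.devices()[:1], axis_names=("x",))


def fake_paged_attention(q, k_pages, v_pages, lengths, page_indices):
  # Stand-in with the same argument order and output shape as the kernel.
  del k_pages, v_pages, lengths, page_indices
  return jnp.zeros_like(q)


sharded = jax.jit(
    shard_map(
        fake_paged_attention,
        mesh=mesh,
        in_specs=(
            P(None, "x", None),        # q: shard the head axis
            P("x", None, None, None),  # k_pages: shard the kv-head axis
            P("x", None, None, None),  # v_pages: shard the kv-head axis
            P(),                       # lengths: replicated
            P(),                       # page_indices: replicated
        ),
        out_specs=P(None, "x", None),
        check_rep=False,
    )
)

q = jnp.ones((batch, heads, head_dim))
k_pages = jnp.ones((kv_heads, total_pages, page_size, head_dim))
v_pages = jnp.ones((kv_heads, total_pages, page_size, head_dim))
lengths = jnp.array([70, 30], dtype=jnp.int32)
page_indices = jnp.arange(batch * pages_per_seq, dtype=jnp.int32).reshape(
    batch, pages_per_seq
)

out = sharded(q, k_pages, v_pages, lengths, page_indices)
print(out.shape)  # (2, 8, 128)
```

Keeping `lengths` and `page_indices` replicated (`P()`) means every shard sees the full paging metadata while holding only its slice of heads, so each device can run the kernel independently on its KV-head pages.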

jetstream_pt/cache_manager.py (18 additions, 8 deletions)

@@ -19,6 +19,7 @@
 import torch_xla2
 
 from jetstream_pt import torchjax
+from jetstream_pt.page_attention_manager import PageAttentionManager
 
 
 # pylint: disable-next=all
@@ -663,18 +664,21 @@ def __init__(
       self,
       cache_k: torch.Tensor,  # previous cache
       cache_v: torch.Tensor,  # previous cache
+      page_attention_manager: PageAttentionManager,
       page_token_indices: torch.Tensor,  # page and token indices for the cache
       sharding,
       env=None,
   ):
     super().__init__()
     self.cache_k = cache_k
     self.cache_v = cache_v
+    self.page_attention_manager = page_attention_manager
     self.page_token_indices = page_token_indices
     self.sharding = sharding
     self.env = env
+    self.stacked = False
 
-  def update(self, key, value):
+  def update(self, key, value, layer_id=0):
     """Update kv cache"""
     keyj, valuej, page_token_indicesj = torchjax.from_torch(
         (key, value, self.page_token_indices)
@@ -683,32 +687,38 @@ def update(self, key, value):
     def _update(cache, x):
       x = x.squeeze(2).transpose((1, 0, 2))
       x = x[:, page_token_indicesj[2], :]
-      head, _, page_size, dim = cache.shape
+      head, _, paged_attention_page_size, dim = cache.shape
       selected_cache = cache[:, page_token_indicesj[0], :, :]
       selected_cache = selected_cache.reshape((head, -1, dim))
 
       selected_cache = selected_cache.at[:, page_token_indicesj[1], :].set(x)
-      selected_cache = selected_cache.reshape((head, -1, page_size, dim))
+      selected_cache = selected_cache.reshape(
+          (head, -1, paged_attention_page_size, dim)
+      )
 
       cache = cache.at[:, page_token_indicesj[0], :, :].set(selected_cache)
       return cache
 
     # pylint: disable-next=all
     self.cache_k._elem = _update(self.cache_k._elem, keyj)
     # pylint: disable-next=all
-    self.cache_k._elem = _update(self.cache_v._elem, valuej)
+    self.cache_v._elem = _update(self.cache_v._elem, valuej)
     return self.cache_k, self.cache_v
 
   def state(self):
     """Get kv cache state"""
     # pylint: disable-next=all
-    return self.cache_k.jax(), self.cache_v.jax()
+    return torchjax.from_torch((self.cache_k, self.cache_v))
+
+  def finalize(self):
+    """Do nothing now"""
+    return
 
   @classmethod
-  def empty(cls, shape, device, bf16_enable, env):
+  def empty(cls, shape, device, env):
     """Create empty kv caches"""
-    default_dtype = jnp.bfloat16 if bf16_enable else jnp.float32
+    default_dtype = jnp.bfloat16 if env.bf16_enable else jnp.float32
     k = jnp.zeros(shape, device=device, dtype=default_dtype)
     v = jnp.zeros(shape, device=device, dtype=default_dtype)
     k, v = torchjax.to_torch((k, v))
-    return cls(k, v, None, device, env=env)
+    return cls(k, v, None, None, device, env=env)
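The `_update` closure scatters the incoming tokens into a page-organized cache of shape `(heads, total_pages, page_size, dim)` using three index rows from `page_token_indices`: the pages touched this step, the token slots inside the flattened view of those pages, and which of the incoming tokens to take. Below is a standalone sketch of that arithmetic with invented shapes and index values; the real indices come from `PageAttentionManager`, not from this snippet.

```python
# Standalone sketch of the paged cache scatter in _update, with invented
# shapes/indices (the real indices come from PageAttentionManager).
import jax.numpy as jnp

heads, total_pages, page_size, dim = 2, 8, 4, 3
cache = jnp.zeros((heads, total_pages, page_size, dim))

# Two new tokens, shaped like the cache input before the squeeze/transpose:
# (batch=2, heads, 1, dim); after squeeze(2).transpose((1, 0, 2)) it becomes
# (heads, tokens, dim).
new_kv = jnp.ones((2, heads, 1, dim))

# Rows of page_token_indices: [0] pages touched this step, [1] token slot
# within the flattened selected pages (page offset * page_size + in-page
# position), [2] which incoming tokens to write.
page_token_indices = jnp.array(
    [
        [3, 5],              # pages 3 and 5 receive a token
        [1, page_size + 2],  # slot 1 of page 3, slot 2 of page 5
        [0, 1],              # take incoming tokens 0 and 1
    ]
)

x = new_kv.squeeze(2).transpose((1, 0, 2))  # (heads, tokens, dim)
x = x[:, page_token_indices[2], :]

selected = cache[:, page_token_indices[0], :, :]      # (heads, 2, page, dim)
selected = selected.reshape((heads, -1, dim))         # flatten selected pages
selected = selected.at[:, page_token_indices[1], :].set(x)
selected = selected.reshape((heads, -1, page_size, dim))

cache = cache.at[:, page_token_indices[0], :, :].set(selected)
print(cache[0, 3, 1], cache[0, 5, 2])  # the two written slots are now ones
```

Gathering only the touched pages, flattening them, and scattering once keeps the update to a pair of gather/scatter ops per layer instead of indexing every token individually.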

jetstream_pt/config.py (14 additions, 0 deletions)

@@ -139,6 +139,18 @@
     "size of top k used when sampling next token",
 )
 
+flags.DEFINE_integer(
+    "paged_attention_total_num_pages",
+    0,
+    "total number of pages per layer for page attention",
+)
+
+flags.DEFINE_integer(
+    "paged_attention_page_size",
+    64,
+    "page size per page",
+)
+
 
 def create_quantization_config_from_flags():
   """Create Quantization Config from cmd flags"""
@@ -213,6 +225,8 @@ def create_engine_from_config_flags():
       generate_cache_stacked=FLAGS.generate_cache_stacked,
       new_cache_stacked=FLAGS.new_cache_stacked,
       lazy_cache_update=FLAGS.lazy_cache_update,
+      paged_attention_total_num_pages=FLAGS.paged_attention_total_num_pages,
+      paged_attention_page_size=FLAGS.paged_attention_page_size,
   )
 
   print("Initialize engine", time.perf_counter() - start)
