Optimize cache update. #151

Conversation
) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]:
  """Ragged multi query attention."""
  with jax.named_scope("ragged_mqa"):
    batch_size, num_heads, head_dim = q.shape
    seq_len = k.shape[1]
    batch_size, time, head_dim = q.shape
Any reason to change num_heads to time?

After vmap, the number-of-heads dimension is gone, so it is indeed the sequence length dimension, which we can also call "time".

I feel "time" is a somewhat misleading variable name here. Can we use q_seq_len instead of time? Also, if we are only using ragged attention in decode state, do we need this query seq len at all, since it will always be 1?
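A minimal sketch of the vmap point above (toy shapes, not the PR's kernel): mapping a single-head attention over the head axis removes that axis inside the mapped function, so q inside is (batch, time, head_dim), where "time" is the query length (1 during decode).

import jax
import jax.numpy as jnp

def mqa_per_head(q, k, v):
  # Inside vmap the head axis is gone: q is (batch, time, head_dim).
  scores = jnp.einsum("btd,bsd->bts", q, k)
  weights = jax.nn.softmax(scores, axis=-1)
  return jnp.einsum("bts,bsd->btd", weights, v)

batch, heads, time, head_dim, seq_len = 2, 8, 1, 128, 16
q = jnp.ones((batch, heads, time, head_dim))
k = jnp.ones((batch, heads, seq_len, head_dim))
v = jnp.ones((batch, heads, seq_len, head_dim))

out = jax.vmap(mqa_per_head, in_axes=(1, 1, 1), out_axes=1)(q, k, v)
print(out.shape)  # (2, 8, 1, 128)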
    seq_len = k.shape[-2]

    stacked = False
    if k.ndim == 5:
Can you share an example where k.ndim is 5 (with block quantization)?

I'm not sure why block quantization matters here. If the cache is stacked, it has these 5 dimensions: layer, batch, number of heads, time, head dim, regardless of whether it's quantized.

Thanks for the clarification!
jetstream_pt/attention_kernel.py (outdated)
    normalize_var: bool,
    quantized: bool,
):
  """Pallas kernel for flash attention."""
Replace "flash" with "ragged"?

Thanks, updated!
  def run():
    q = q_ref[...].astype(jnp.float32)
    k = k_ref[...].astype(jnp.float32)
    v = v_ref[...].astype(jnp.float32)
Do we have to convert to fp32? Can we use bf16?

All the arithmetic operations only support f32, and it reports an error if forced to bf16. Confirmed the constraint with the XLA team: b/340263269 and b/341729764.
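To illustrate the pattern (a simplified stand-alone function, not the Pallas kernel itself): the bf16 inputs are upcast for the math, and the result can be cast back for storage.

import jax
import jax.numpy as jnp

def attend_block(q_bf16, k_bf16, v_bf16):
  # Upcast to f32 for the arithmetic, since bf16 math hits the constraint above.
  q = q_bf16.astype(jnp.float32)
  k = k_bf16.astype(jnp.float32)
  v = v_bf16.astype(jnp.float32)
  scores = q @ k.T
  weights = jax.nn.softmax(scores, axis=-1)
  # Cast back to the storage dtype on the way out.
  return (weights @ v).astype(q_bf16.dtype)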
    return layer_ref[0], b_next, i_next, 0
  return b_next, i_next, 0


def kv_scale_index_map(b, i, layer_ref, start_ref, end_ref, *_):
Can you share why i_next is assigned to a different position in kv_index_map vs. kv_scale_index_map?

i_next doesn't get a different value. It's in a different position because the scale has the shape (batch, 1, kv_length), and grid[1] applies to the last dimension there. That's why we put i_next in that dimension.
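A simplified illustration of that layout difference (hypothetical signatures, no ragged-block logic, so not the PR's actual index maps):

# K blocks have shape (batch, block_size, head_dim): the kv/time axis is axis 1.
def kv_index_map(b, i, start_ref, end_ref, *_):
  i_next = i  # the real kernel derives i_next from precomputed ragged block indices
  return b, i_next, 0

# K-scale blocks have shape (batch, 1, block_size): the kv/time axis is the last axis.
def kv_scale_index_map(b, i, start_ref, end_ref, *_):
  i_next = i  # same value as above; only its position differs
  return b, 0, i_next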
Combined with precompute_ragged_block_indices, for a given decode with start = jnp.asarray([11, 0, 10]), input_pos = jnp.asarray([15, 9, 8]), cache_len = 16, and block_size = 4, can you share what the expected kv index map is?
@@ -310,15 +586,78 @@ def dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None):
  return output


def flash_attention(
Flash attention uses blocked q, k, v to do tiled compute. Is this function just vanilla attention?

Flash attention's key capability is computing the local softmax blockwise, which is exactly what we are doing here. How to divide the blocks is up to the user; we leveraged this to split the attention calculation between the existing cache and the new cache. So this is indeed flash attention.

Is there an upper-level function that calls this local attention? If this function only handles each q_block, k_block, and v_block of the loop, should we rename it to block_attention? For generation, flash attention needs to dynamically select the max and rescale the softmax, like the code below from the ragged_mqa... function:

m_curr = jax.lax.broadcast_in_dim(m_curr, m_prev.shape, (0,))
m_next = jnp.maximum(m_prev, m_curr)
alpha = jnp.exp(m_prev - m_next)
beta = jnp.exp(m_curr - m_next)
l_next = alpha * l_prev + beta * l_curr
l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next)

Correct me if I'm wrong.
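For reference, a small self-contained check of that blockwise combine (arbitrary numbers, not the kernel's code): splitting the scores into two blocks and merging the local max/sum reproduces the full softmax.

import jax
import jax.numpy as jnp

scores = jnp.array([1.0, 3.0, 2.0, 0.5, 4.0, 1.5])
block_a, block_b = scores[:4], scores[4:]  # e.g. existing cache vs. new cache

def local_stats(s):
  m = jnp.max(s)
  p = jnp.exp(s - m)
  return m, jnp.sum(p), p

m_prev, l_prev, p_a = local_stats(block_a)
m_curr, l_curr, p_b = local_stats(block_b)

m_next = jnp.maximum(m_prev, m_curr)
alpha = jnp.exp(m_prev - m_next)
beta = jnp.exp(m_curr - m_next)
l_next = alpha * l_prev + beta * l_curr

combined = jnp.concatenate([alpha * p_a, beta * p_b]) / l_next
print(jnp.allclose(combined, jax.nn.softmax(scores)))  # True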
        self.input_pos,
    )

  def update(self, key, value, layer_id: int):
Great implementation! But in general, I feel the logic is too complex to maintain. Can we have separate KVCacheGenerate classes to handle ring_buffer, ragged attention, and stacked vs. unstacked?

I was thinking about merging Int8KVCacheGenerate and KVCacheGenerate because there is a lot of shared code. I can combine all 4 additional flags (lazy_cache_update, generate_cache_stacked, new_cache_stacked, flash_attention) into 1 to simplify the logic, since these flags only help my experimentation and should not be exposed to the user. Wdyt?

Thanks. My main concern is that the current code logic is too complex to read and maintain. The cache manager used to be a very straightforward implementation, but right now the logic is very complex. Let's only keep the most optimized code in the repo.
    required=False,
)
flags.DEFINE_bool(
    "generate_cache_stacked",
What are the benefits of cache_stacked?

It reduces the DMA transfer time; minimizing the number of DMA transfers helps.

Also, XLA handles cache insertion for all the layers much more efficiently than the user iterating over the layer dimension.
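A toy sketch of that point (assumed shapes, not the repo's cache classes): with a stacked cache, one dynamic_update_slice writes the current token for every layer at once, instead of issuing a separate update per layer.

import jax
import jax.numpy as jnp

layers, batch, heads, cache_len, head_dim = 4, 2, 8, 16, 128
stacked_cache = jnp.zeros((layers, batch, heads, cache_len, head_dim))
new_entries = jnp.ones((layers, batch, heads, 1, head_dim))  # current step, all layers

@jax.jit
def insert(cache, new, pos):
  # Single update covering the layer dimension, rather than a Python loop per layer.
  return jax.lax.dynamic_update_slice(cache, new, (0, 0, 0, pos, 0))

updated = insert(stacked_cache, new_entries, 7)
print(updated.shape)  # (4, 2, 8, 16, 128)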
"Whether to enable ring buffer", | ||
required=False, | ||
) | ||
flags.DEFINE_bool( | ||
"flash_attention", |
Do you plan to enable flash_attention by itself, without ragged attention?

No, ragged attention has better performance than flash attention. As I indicated in the description, it only takes effect in test mode, which means the user cannot directly enable it in interactive, offline, or server mode.
    input_specs=(*([qkv_pspec] * 3), *([others_pspec] * 4)),
    output_specs=(qkv_pspec, (others_pspec, others_pspec)),
    sharding_axis=self.shard_axis,
    input_specs=(q_pspec, q_pspec, q_pspec, *([others_pspec] * 7)),
Correct me if I'm wrong: ragged_attention_new doesn't support generate_cache_stacked?

ragged_attention_new is for the new cache from the current step, which has a length of 1, so there is nothing to stack.

A 15% improvement is a great achievement! I assume the tested configuration uses the stacked, left-aligned cache + ragged attention. Do you have any performance numbers for a left-aligned (unstacked) cache + ragged attention?

When the cache is left aligned and unstacked, the data transfer overhead is non-negligible. I tried flash attention, which came to 90ms per step; that overhead has nothing to do with which attention you are using.
jetstream_pt/cache_manager.py (outdated)
        self.new_v_scaler,
    ]
    (
        self.cache_k._elem,
._elem seems spurious, as _elem is already a jax array. So it's either x._elem = foo(jax_array_inputs) or x = call_jax(foo, torch_tensor_inputs).
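To illustrate the second pattern with a hypothetical helper (the repo's real call_jax wrapper, if any, may differ in name and behavior): convert torch tensors at the boundary, run the jax function, and convert back, rather than reaching into ._elem.

import numpy as np
import torch
import jax.numpy as jnp

def call_jax(fn, *torch_tensors):
  # Hypothetical boundary helper: torch -> jax, run fn, jax -> torch.
  jax_args = [jnp.asarray(t.detach().cpu().numpy()) for t in torch_tensors]
  out = fn(*jax_args)
  return torch.from_numpy(np.array(out))

x = call_jax(lambda a, b: a + b, torch.ones(3), torch.ones(3))
print(x)  # tensor([2., 2., 2.])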
Makes sense; I decided to remove _elem since it violates lint anyway.
jetstream_pt/layers.py (outdated)
@@ -367,9 +367,25 @@ def apply_rotary_emb(
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
  """torch.repeat_interleave(x, dim=2, repeats=n_rep)."""

  bs, n_kv_heads, slen, head_dim = x.shape
  bs, n_kv_heads, slen, head_dim = (
bs, n_kv_heads, slen, head_dim, *_ = x.shape

I see, should it be *_, bs, n_kv_heads, slen, head_dim = x.shape?
jetstream_pt/layers.py
Outdated
x.shape[-2], | ||
x.shape[-1], | ||
) | ||
if x.ndim == 5: |
stacked = x.ndim == 5

Better! Thanks!
  if n_rep == 1:
    return x
  if stacked:

Or just check ndim == 5 here and remove the stacked var?

I'd probably prefer to keep stacked, to make the code clearer.
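Putting the suggestions above together, a minimal sketch of a repeat_kv that handles both the plain 4-D cache and the stacked 5-D cache (illustrative, not the PR's final code):

import torch

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
  """Repeat kv heads along the head axis for 4-D or stacked 5-D caches."""
  *leading, n_kv_heads, slen, head_dim = x.shape  # leading is [bs] or [layers, bs]
  if n_rep == 1:
    return x
  return (
      x.unsqueeze(-3)  # add a repeat axis next to the kv-head axis
      .expand(*leading, n_kv_heads, n_rep, slen, head_dim)
      .reshape(*leading, n_kv_heads * n_rep, slen, head_dim)
  )

print(repeat_kv(torch.zeros(2, 8, 16, 128), 4).shape)     # torch.Size([2, 32, 16, 128])
print(repeat_kv(torch.zeros(4, 2, 8, 16, 128), 4).shape)  # torch.Size([4, 2, 32, 16, 128])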
Fixed based on your comments, plus all the unit tests and lint errors. Please let me know if you have any other comments or suggestions. @qihqi @FanhaiLu1

There are some updates to deps/Jetstream; is that intentional?
    seq_len = k.shape[-2]

    stacked = False
    if k.ndim == 4:

The vmap reduces the head dim, so the stacked ndim becomes 4 instead of 5. Correct me if I'm wrong. I'm also wondering: do we need a vmap in ragged attention at all? The shard_map already did the first reduction, cutting the head dim from 32 to 4 (taking llama2 7b on v5e-8 as an example). Can we process 4 heads in a single kernel? Is there a performance regression when using multiple heads in ragged attention compared with single-head attention?

It's not a must. Reducing the number-of-heads dimension to 1 turns the MHA into MQA; that's just for compatibility.
      jnp.array([layer]),
      start,
      end,
      end,  # line_end, not actually used

Thanks for clarifying!
There are new lint errors; can you fix them?
… ring buffer support then fix the mask. Int8 updates also included but not tested.
…o the end of existing cache attention.
…on. Refactor to use only 1 flash attention kernel. Changes the modified ring buffer ragged attention kernel with quantization, layer, etc.
…run_interactive in CPU mode can work. When we default ring buffer to false, should add additional flags to run_interactive CI to set test mode to true so that pallas kernel can run.
4b6d862 to 595ead2
* Fix TPU head resource name for v4 and v5e * fix format
* add xla2 fix * update jax version * revert jax TPU version
…rom merge; Fix lints;
Fixed all the lint issues.

I will remove precompute_ragged_block_indices, clean up the ragged attention implementation (e.g. remove the one for the ring buffer), and simplify the flags for the non-ring-buffer case, which will simplify the cache manager, in a subsequent PR. I'll push this PR first since it has been standing alone for a while.
We used to insert into the cache inside attention, then use the updated cache for the calculation. With the help of flash/ragged attention, we can delay the cache insertion to the end of each step. By switching to a left-aligned, stacked cache, we minimize the data transfer to HBM and therefore improve performance: the decode step time dropped from 52ms to 42ms. The left-aligned cache also improves insert efficiency. Overall benchmark performance is boosted by 15%.
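A toy sketch of the lazy-update idea described above (assumed shapes and a simplified combine, not the repo's cache manager or kernels): attention reads the stale cache plus the current step's k/v, and the cache is written only once at the end of the step.

import jax
import jax.numpy as jnp

def decode_step(cache_k, cache_v, q, new_k, new_v, pos):
  # q, new_k, new_v: (heads, 1, d); cache_k, cache_v: (heads, cache_len, d)
  cache_len = cache_k.shape[1]
  scores = jnp.concatenate(
      [jnp.einsum("htd,hsd->hts", q, cache_k),  # against the existing (stale) cache
       jnp.einsum("htd,hsd->hts", q, new_k)],   # against the current token
      axis=-1,
  )
  valid = jnp.concatenate([jnp.arange(cache_len) < pos, jnp.array([True])])
  weights = jax.nn.softmax(jnp.where(valid, scores, -1e9), axis=-1)
  out = jnp.einsum("hts,hsd->htd", weights,
                   jnp.concatenate([cache_v, new_v], axis=1))
  # Cache insertion is deferred to the end of the step, after attention.
  cache_k = jax.lax.dynamic_update_slice(cache_k, new_k, (0, pos, 0))
  cache_v = jax.lax.dynamic_update_slice(cache_v, new_v, (0, pos, 0))
  return out, cache_k, cache_v

heads, cache_len, d = 8, 16, 128
out, k, v = decode_step(
    jnp.zeros((heads, cache_len, d)), jnp.zeros((heads, cache_len, d)),
    jnp.ones((heads, 1, d)), jnp.ones((heads, 1, d)), jnp.ones((heads, 1, d)), 5)
print(out.shape, k.shape)  # (8, 1, 128) (8, 16, 128)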