Fix the performance regression with ragged attention on for llama2 7b. #172


Merged: 4 commits, Aug 20, 2024
59 changes: 31 additions & 28 deletions jetstream_pt/attention_kernel.py
@@ -558,33 +558,6 @@ def ragged_mha(
return out, (m, l)


def _dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None):
"""The vanilla attention kernel implementation."""

bsz, _, _, head_dim = xq.shape
with jax.named_scope("attn_mat1"):
## Attention start
# scores = torch.einsum(jnp.einsum, "ijkl,ikml->ikjm", xq, keys) / math.sqrt(self.head_dim)
scores = torch.einsum("ikjl,ikml->ikjm", xq, keys) / math.sqrt(head_dim)
if k_scaler is not None:
scores = scores * (k_scaler.reshape(bsz, 1, 1, keys.shape[2]))
if mask is not None:
# if mask.shape != (1,1,16,16):
# breakpoint()
scores = scores + mask # (bs, n_local_heads, seqlen, max_seqlen)
with jax.named_scope("attn_soft"):
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
if v_scaler is not None:
scores = scores * v_scaler.reshape((bsz, 1, 1, keys.shape[2]))

with jax.named_scope("attn_mat2"):
# output = torch.einsum(
# "ikjm,ikml->ikjl", scores, values
# ) # (bs, n_local_heads, seqlen, head_dim)
output = torch.einsum("ikjm,ikml->ikjl", scores, values)
return output


def reshape_heads(xq, keys):
"""Reshapes the query head for GQA"""
bq, hq, tq, dq = xq.shape
@@ -607,6 +580,29 @@ def reshape_outputs(rep, o, m=None, d=None):
return o, (m, d)


def _dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None):
"""The vanilla attention kernel implementation."""

bsz, _, _, head_dim = xq.shape
with jax.named_scope("attn_mat1"):
## Attention start
scores = torch.einsum("ikjl,ikml->ikjm", xq, keys) / math.sqrt(head_dim)
if k_scaler is not None:
scores = scores * (k_scaler.reshape(bsz, 1, 1, keys.shape[2]))
if mask is not None:
scores = scores + mask # (bs, n_local_heads, seqlen, max_seqlen)
with jax.named_scope("attn_soft"):
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
if v_scaler is not None:
scores = scores * v_scaler.reshape((bsz, 1, 1, keys.shape[2]))

with jax.named_scope("attn_mat2"):
output = torch.einsum(
"ikjm,ikml->ikjl", scores, values
) # (bs, n_local_heads, seqlen, head_dim)
return output


def dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None):
"""The vanilla attention kernel implementation."""
xq, rep = reshape_heads(xq, keys)
@@ -680,7 +676,14 @@ def flash_attention(
"""Flash attention kernel."""
xq, rep = reshape_heads(xq, keys)
o, (logits_max, denominator) = _flash_attention(
xq, keys, values, k_scaler, v_scaler, mask
xq=xq,
keys=keys,
values=values,
layer=layer,
k_scaler=k_scaler,
v_scaler=v_scaler,
mask=mask,
normalize_var=normalize_var,
)
return reshape_outputs(rep, o, logits_max, denominator)
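
For reference, a minimal self-contained sketch of the einsum contraction used by _dense_attention in this file; the batch, head, and sequence sizes below are illustrative assumptions, not values from the PR.

import math

import torch
import torch.nn.functional as F

# Assumed toy shapes: batch=2, heads=4, query len=1, cache len=16, head_dim=8.
xq = torch.randn(2, 4, 1, 8)      # (bs, n_heads, seqlen, head_dim)
keys = torch.randn(2, 4, 16, 8)   # (bs, n_heads, max_seqlen, head_dim)
values = torch.randn(2, 4, 16, 8)

# Same contraction as _dense_attention: scores is (bs, n_heads, seqlen, max_seqlen).
scores = torch.einsum("ikjl,ikml->ikjm", xq, keys) / math.sqrt(xq.shape[-1])
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
# Weighted sum over the cache dimension gives back (bs, n_heads, seqlen, head_dim).
output = torch.einsum("ikjm,ikml->ikjl", scores, values)
assert output.shape == (2, 4, 1, 8)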

120 changes: 66 additions & 54 deletions jetstream_pt/layers.py
@@ -433,12 +433,11 @@ def attend(xq, keys, values, local_mask=None):
# When GQA is enabled, it is not necessary to expand
if not (self.env.ragged_mha and n_rep > 1) and seqlen == 1:
true_len = 2
# xq = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3]))
xq = torch.nn.functional.pad(
xq, (0, 0, 0, true_len - seqlen), "constant", 0
)

if self.env.ragged_mha and seqlen == 1:
if self.env.ragged_mha and seqlen == 1 and keys.shape[-2] > 1:
local_output, (local_max, local_denom) = torch_xla2.interop.call_jax(
impl,
xq,
@@ -449,15 +448,28 @@ def attend(xq, keys, values, local_mask=None):
end,
ragged_batch_index,
ragged_block_index,
None, # k_scaler
None, # v_scaler
)
elif self.env.flash_attention and seqlen == 1:
with torch_xla2.default_env():
local_output, (local_max, local_denom) = self.flash_attention(
xq, keys, values, self.layer_id, mask=local_mask
xq=xq,
keys=keys,
values=values,
layer=self.layer_id,
k_scaler=None,
v_scaler=None,
mask=local_mask,
)
else:
local_output = self.dense_attention(
xq, keys, values, None, None, local_mask
xq=xq,
keys=keys,
values=values,
k_scaler=None,
v_scaler=None,
mask=local_mask,
)
local_max = None
local_denom = None
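
The keys.shape[-2] > 1 guard added above means the ragged kernel is only invoked when the key cache holds more than one token along its sequence dimension; otherwise the code falls through to flash or dense attention. A minimal sketch of that dispatch condition, with assumed illustrative shapes:

def use_ragged_kernel(ragged_mha: bool, seqlen: int, keys_shape: tuple) -> bool:
  # Mirrors the condition in attend(): ragged attention only while decoding
  # (seqlen == 1) against a cache that already holds more than one token.
  return ragged_mha and seqlen == 1 and keys_shape[-2] > 1

assert use_ragged_kernel(True, 1, (1, 8, 1024, 128))
assert not use_ragged_kernel(True, 1, (1, 8, 1, 128))      # single-entry cache
assert not use_ragged_kernel(True, 16, (1, 8, 1024, 128))  # prefill path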
@@ -474,9 +486,6 @@ def attend(xq, keys, values, local_mask=None):
if local_denom is not None:
local_denom = local_denom[:, :, 0:seqlen, :]

# print(f"attention kernel local_output {local_output.shape} seqlen {seqlen}")
# if local_max is not None and local_denom is not None:
# print(f"local_max {local_max.shape} local_denom {local_denom.shape}")
self.env.apply_sharding(local_output, axis=self.q_shard_axis)
return local_output, (local_max, local_denom)

@@ -486,7 +495,7 @@ def attend(xq, keys, values, local_mask=None):
# print(f"attention kernel xq {xq.shape} seqlen {seqlen} keys {keys.shape} mask {mask.shape}")
with jax.named_scope("attn_qkv"):
existing_output, (existing_max, existing_denom) = attend(
xq, orig_keys, orig_values, mask
xq=xq, keys=orig_keys, values=orig_values, local_mask=mask
)
# Updating cache during each step still has very large impact on latency.
# For non flash attention or prefill, existing output contains everything
@@ -495,23 +504,20 @@ def attend(xq, keys, values, local_mask=None):

# For flash attention, existing output contains the existing kv cache generated logits
with jax.named_scope("attn_new_qkv"):
new_output, (new_max, new_denom) = attend(xq, xk, xv, None)
new_output, (new_max, new_denom) = attend(
xq=xq, keys=xk, values=xv, local_mask=None
)

with jax.named_scope("attn_global"):
# print(f"existing_output {existing_output} existing_max {existing_max} existing_denom {existing_denom}")
# print(f"new_output {new_output} new_max {new_max} new_denom {new_denom}")

global_sum = existing_denom * torch.exp(
existing_max
) + new_denom * torch.exp(new_max)
existing_output = (
existing_output
* existing_denom
* torch.exp(existing_max)
/ global_sum
)
new_output = new_output * new_denom * torch.exp(new_max) / global_sum
attn_out = existing_output + new_output
global_max = torch.max(existing_max, new_max)
alpha = torch.exp(existing_max - global_max)
beta = torch.exp(new_max - global_max)
global_denom = alpha * existing_denom + beta * new_denom
# global_denom = torch.where(global_denom == 0.0, 1.0, global_denom)
attn_out = (
existing_denom * alpha * existing_output
+ beta * new_output * new_denom
) / global_denom

return attn_out
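
The rewritten attn_global block above merges the cached-KV and new-KV partial results in the standard numerically stable way, shifting by the running max before exponentiating. A self-contained sketch of the same merge, checked against one full softmax over assumed toy data:

import torch

def merge(o1, m1, d1, o2, m2, d2):
  # Combine two partial attention results (output, running max, denominator)
  # without exponentiating unshifted logits, as in the attn_global block above.
  m = torch.maximum(m1, m2)
  alpha, beta = torch.exp(m1 - m), torch.exp(m2 - m)
  denom = alpha * d1 + beta * d2
  return (alpha * d1 * o1 + beta * d2 * o2) / denom

def partial(scores, values):
  # One flash-attention-style partial result over a slice of the cache.
  m = scores.max()
  w = torch.exp(scores - m)
  return (w @ values) / w.sum(), m, w.sum()

scores = torch.randn(10)     # logits for one query against 10 cached positions
values = torch.randn(10, 4)  # one 4-dim value vector per position
full = torch.softmax(scores, -1) @ values

o1, m1, d1 = partial(scores[:6], values[:6])
o2, m2, d2 = partial(scores[6:], values[6:])
assert torch.allclose(merge(o1, m1, d1, o2, m2, d2), full, atol=1e-6)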

@@ -588,8 +594,7 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None):
xq, (0, 0, 0, true_len - seqlen), "constant", 0
)

# We are not using ragged attention for prefill yet.
if self.env.ragged_mha and seqlen == 1:
if self.env.ragged_mha and seqlen == 1 and keys.shape[-2] > 1:
local_output, (local_max, local_denom) = torch_xla2.interop.call_jax(
impl,
xq,
@@ -606,17 +611,22 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None):
elif self.env.flash_attention and seqlen == 1:
with torch_xla2.default_env():
local_output, (local_max, local_denom) = self.flash_attention(
xq,
keys,
values,
self.layer_id,
k_scaler,
v_scaler,
xq=xq,
keys=keys,
values=values,
layer=self.layer_id,
k_scaler=k_scaler,
v_scaler=v_scaler,
mask=local_mask,
)
else:
local_output = self.dense_attention(
xq, keys, values, k_scaler, v_scaler, local_mask
xq=xq,
keys=keys,
values=values,
k_scaler=k_scaler,
v_scaler=v_scaler,
mask=local_mask,
)
local_max = None
local_denom = None
@@ -648,7 +658,12 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None):
) = cache.update(xk, xv, self.layer_id)
with jax.named_scope("attn_qkv"):
existing_output, (existing_max, existing_denom) = attend(
xq, orig_keys, orig_values, k_scaler, v_scaler, mask
xq=xq,
keys=orig_keys,
values=orig_values,
k_scaler=k_scaler,
v_scaler=v_scaler,
local_mask=mask,
)

# For non flash attention or prefill, existing output contains everything
@@ -663,18 +678,15 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None):
)

with jax.named_scope("attn_global"):
global_sum = existing_denom * torch.exp(
existing_max
) + new_denom * torch.exp(new_max)
existing_output = (
existing_output
* existing_denom
* torch.exp(existing_max)
/ global_sum
)
new_output = new_output * new_denom * torch.exp(new_max) / global_sum
attn_out = existing_output + new_output

global_max = torch.max(existing_max, new_max)
alpha = torch.exp(existing_max - global_max)
beta = torch.exp(new_max - global_max)
global_denom = alpha * existing_denom + beta * new_denom
# global_denom = torch.where(global_denom == 0.0, 1.0, global_denom)
attn_out = (
existing_denom * alpha * existing_output
+ beta * new_output * new_denom
) / global_denom
return attn_out


@@ -800,16 +812,16 @@ def forward(
# if cache is not None and cache.cache_k is not None:
# print(f"xq {xq.shape} xk {xk.shape} cache shape {cache.cache_k.shape}")
output = self.attention_kernel(
xq,
xk,
xv,
mask,
xq=xq,
xk=xk,
xv=xv,
mask=mask,
# cache[self.layer_id],
cache,
start,
end,
ragged_batch_index,
ragged_block_index,
cache=cache,
start=start,
end=end,
ragged_batch_index=ragged_batch_index,
ragged_block_index=ragged_block_index,
).type_as(xq)
# print(f"output {output.shape}")
output = output.transpose(-3, -2).contiguous().view(bsz, seqlen, -1)
39 changes: 34 additions & 5 deletions tests/test_model_impl.py
@@ -18,6 +18,8 @@
import torch
import torch_xla2

from absl.testing import parameterized

from jetstream_pt.third_party.llama import model_exportable
from jetstream_pt.third_party.llama import model_original
from jetstream_pt.third_party.gemma import model_original as gemma_orig
@@ -32,7 +34,7 @@
from . import helpers


class ModelComponentTest(unittest.TestCase):
class ModelComponentTest(parameterized.TestCase):
"""Test diff between original model and xla model for transformer,
transformer block, attention, and other components in the model"""

@@ -75,7 +77,7 @@ def _generate_mask(self, cache_length, pos, seqlen, ring_buffer=True):
if ring_buffer:
cond = jnp.logical_and(x <= pos, x >= pos - seqlen)
else:
# Left aligned buffer we postpone the cache update
# Left-aligned buffer: we postpone the cache update, therefore mask out pos
cond = jnp.logical_and(x < pos, x >= pos - seqlen)
res = jnp.where(cond, 0, float("-inf"))
return torchjax.to_torch(res)
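
A small illustration of the two mask conditions above, using assumed toy values: with a ring buffer the slot at pos is already updated and stays visible, while the left-aligned buffer postpones the cache update, so pos itself is masked out.

import jax.numpy as jnp

cache_length, pos, seqlen = 8, 5, 1
x = jnp.arange(cache_length)

ring_cond = jnp.logical_and(x <= pos, x >= pos - seqlen)         # pos visible
left_aligned_cond = jnp.logical_and(x < pos, x >= pos - seqlen)  # pos masked

ring_mask = jnp.where(ring_cond, 0.0, float("-inf"))                  # slots 4 and 5 open
left_aligned_mask = jnp.where(left_aligned_cond, 0.0, float("-inf"))  # only slot 4 open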
@@ -98,10 +100,33 @@ def _make_one_cache_for_generate(self, env, pos):
)
return cache_decode

@parameterized.named_parameters(
("ring_buffer", "ring"),
("non_ring_buffer_flash_attention", "flash"),
("non_ring_buffer_ragged_attention", "ragged"),
)
# pylint: disable-next=all
def test_attention(self):
def test_attention(self, attn_type):
torch.manual_seed(0)
env, model_arg = helpers.make_env_tiny(False)
if attn_type == "ring":
env.lazy_cache_update = False
env.ragged_mha = False
env.flash_attention = False
self.generate_cache_stacked = False
env.ring_buffer = True
elif attn_type == "ragged":
env.lazy_cache_update = True
env.ragged_mha = True
env.flash_attention = True
self.generate_cache_stacked = True
env.ring_buffer = False
elif attn_type == "flash":
env.lazy_cache_update = True
env.ragged_mha = False
env.flash_attention = True
self.generate_cache_stacked = True
env.ring_buffer = False

attention_orig = model_original.Attention(model_arg)
attention_ours = layers.Attention(
@@ -167,10 +192,14 @@ def test_attention(self):
)
expected_out = attention_orig(*inputs_orig2)
cache_decode.input_pos = [pos] # next position to update
mask = self._generate_mask(env.cache_sequence_length, pos, seqlen)
mask = self._generate_mask(
env.cache_sequence_length, pos, seqlen, env.ring_buffer
)
mask = mask.reshape(1, 1, 1, -1) # seq dim is the last one
freqs_cis = freqs_cis.reshape(batch, 1, -1)
input_ours2 = (x2, freqs_cis, mask, cache_decode)
start = torch.tensor([0] * batch, dtype=torch.int)
end = torch.tensor([pos] * batch, dtype=torch.int)
input_ours2 = (x2, freqs_cis, mask, cache_decode, start, end)
result_torch = helpers.call_xla_model(
attention_ours, attention_orig.state_dict(), input_ours2
)
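
For context on the parameterized setup added earlier in this file, a minimal standalone sketch of the absl pattern; the class and method names below are hypothetical, and only the decorator arguments mirror the diff:

from absl.testing import absltest, parameterized

class AttentionConfigTest(parameterized.TestCase):

  @parameterized.named_parameters(
      ("ring_buffer", "ring"),
      ("non_ring_buffer_flash_attention", "flash"),
      ("non_ring_buffer_ragged_attention", "ragged"),
  )
  def test_config(self, attn_type):
    # Each named parameter expands into its own test case,
    # e.g. test_config_ring_buffer and test_config_non_ring_buffer_flash_attention.
    self.assertIn(attn_type, ("ring", "flash", "ragged"))

if __name__ == "__main__":
  absltest.main()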