Skip to content

Commit 2cf3673

Browse files
committed
Fix the performance regression with ragged attention enabled for the llama2 7b model.
1 parent 0495312 commit 2cf3673

File tree

3 files changed

+5
-4
lines changed

3 files changed

+5
-4
lines changed

jetstream_pt/environment.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,10 @@ class JetEngineEnvironmentData:
124124
# The ratio between query heads and kv heads
125125
n_reps: int = 0
126126

127+
127128
# pylint: disable-next=all
128129
class JetEngineEnvironment:
129-
# pylint: disable-next=all
130+
# pylint: disable-next=all
130131
def __init__(self, data: JetEngineEnvironmentData):
131132
self._data = data
132133

jetstream_pt/layers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ def attend(xq, keys, values, local_mask=None):
438438
xq, (0, 0, 0, true_len - seqlen), "constant", 0
439439
)
440440

441-
if self.env.ragged_mha and seqlen == 1:
441+
if self.env.ragged_mha and seqlen == 1 and keys.shape[-2] > 1:
442442
local_output, (local_max, local_denom) = torch_xla2.interop.call_jax(
443443
impl,
444444
xq,
@@ -589,7 +589,7 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None):
589589
)
590590

591591
# We are not using ragged attention for prefill yet.
592-
if self.env.ragged_mha and seqlen == 1:
592+
if self.env.ragged_mha and seqlen == 1 and keys.shape[-2] > 1:
593593
local_output, (local_max, local_denom) = torch_xla2.interop.call_jax(
594594
impl,
595595
xq,

tests/test_run_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def reset_flags(self):
4545

4646
def setup(self):
4747
"""Setup."""
48-
# pylint: disable-next=all
48+
# pylint: disable-next=all
4949
from run_server import flags
5050

5151
f = flags.FLAGS

0 commit comments

Comments
 (0)