Commit ef351e9

Merge pull request meta-llama#900 from flu0r1ne/main
Fix key-value caching for seqlen != 1 (Issue meta-llama#899)
2 parents 4835a30 + cd0719d

File tree

1 file changed: +16 -6 lines

llama/model.py

Lines changed: 16 additions & 6 deletions
@@ -289,12 +289,12 @@ def forward(
         values = self.cache_v[:bsz, : start_pos + seqlen]

         # repeat k/v heads if n_kv_heads < n_heads
-        keys = repeat_kv(keys, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
-        values = repeat_kv(values, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
+        keys = repeat_kv(keys, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
+        values = repeat_kv(values, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)

         xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        keys = keys.transpose(1, 2)
-        values = values.transpose(1, 2)
+        keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
+        values = values.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
         scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
         if mask is not None:
             scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
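The corrected comments in this hunk reflect that, with key-value caching, queries exist only for the seqlen new tokens while keys and values span all cache_len + seqlen positions. A minimal shape-check sketch of that flow (all sizes are illustrative, not from the model configs; repeat_kv is omitted by assuming n_rep == 1):

import math
import torch

# Illustrative sizes, not taken from any real config.
bs, n_local_heads, head_dim = 2, 4, 16
cache_len, seqlen = 5, 3  # 5 cached positions, 3 new tokens

# Queries cover only the new tokens; keys/values include the cache.
xq = torch.randn(bs, seqlen, n_local_heads, head_dim)
keys = torch.randn(bs, cache_len + seqlen, n_local_heads, head_dim)

xq = xq.transpose(1, 2)      # (bs, n_local_heads, seqlen, head_dim)
keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)

scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)
print(scores.shape)  # torch.Size([2, 4, 3, 8]) == (bs, n_local_heads, seqlen, cache_len + seqlen)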
@@ -474,9 +474,19 @@ def forward(self, tokens: torch.Tensor, start_pos: int):
         mask = None
         if seqlen > 1:
             mask = torch.full(
-                (1, 1, seqlen, seqlen), float("-inf"), device=tokens.device
+                (seqlen, seqlen), float("-inf"), device=tokens.device
             )
-            mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
+
+            mask = torch.triu(mask, diagonal=1)
+
+            # When performing key-value caching, we compute the attention scores
+            # only for the new sequence. Thus, the matrix of scores is of size
+            # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
+            # j > cache_len + i, since row i corresponds to token cache_len + i.
+            mask = torch.hstack([
+                torch.zeros((seqlen, start_pos), device=tokens.device),
+                mask
+            ]).type_as(h)

         for layer in self.layers:
             h = layer(h, start_pos, freqs_cis, mask)
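To check the new mask construction concretely, here is a small sketch with assumed values seqlen = 3 and start_pos = 2 (chosen purely for illustration): each new token may attend to every cached token, plus itself and earlier new tokens.

import torch

seqlen, start_pos = 3, 2  # illustrative values, not from the source

# Causal mask over the new tokens only: entry (i, j) is -inf when j > i.
mask = torch.full((seqlen, seqlen), float("-inf"))
mask = torch.triu(mask, diagonal=1)

# Prepend start_pos zero columns: new tokens attend freely to the cache.
mask = torch.hstack([torch.zeros((seqlen, start_pos)), mask])

print(mask)
# tensor([[0., 0., 0., -inf, -inf],
#         [0., 0., 0., 0., -inf],
#         [0., 0., 0., 0., 0.]])

Row i corresponds to absolute position start_pos + i, so entry (i, j) is masked exactly when j > start_pos + i, matching the comment added in the diff.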
