@@ -289,12 +289,12 @@ def forward(
289
289
values = self .cache_v [:bsz , : start_pos + seqlen ]
290
290
291
291
# repeat k/v heads if n_kv_heads < n_heads
292
- keys = repeat_kv (keys , self .n_rep ) # (bs, seqlen, n_local_heads, head_dim)
293
- values = repeat_kv (values , self .n_rep ) # (bs, seqlen, n_local_heads, head_dim)
292
+ keys = repeat_kv (keys , self .n_rep ) # (bs, cache_len + seqlen, n_local_heads, head_dim)
293
+ values = repeat_kv (values , self .n_rep ) # (bs, cache_len + seqlen, n_local_heads, head_dim)
294
294
295
295
xq = xq .transpose (1 , 2 ) # (bs, n_local_heads, seqlen, head_dim)
296
- keys = keys .transpose (1 , 2 )
297
- values = values .transpose (1 , 2 )
296
+ keys = keys .transpose (1 , 2 ) # (bs, n_local_heads, cache_len + seqlen, head_dim)
297
+ values = values .transpose (1 , 2 ) # (bs, n_local_heads, cache_len + seqlen, head_dim)
298
298
scores = torch .matmul (xq , keys .transpose (2 , 3 )) / math .sqrt (self .head_dim )
299
299
if mask is not None :
300
300
scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen)
@@ -474,9 +474,19 @@ def forward(self, tokens: torch.Tensor, start_pos: int):
474
474
mask = None
475
475
if seqlen > 1 :
476
476
mask = torch .full (
477
- (1 , 1 , seqlen , seqlen ), float ("-inf" ), device = tokens .device
477
+ (seqlen , seqlen ), float ("-inf" ), device = tokens .device
478
478
)
479
- mask = torch .triu (mask , diagonal = start_pos + 1 ).type_as (h )
479
+
480
+ mask = torch .triu (mask , diagonal = 1 )
481
+
482
+ # When performing key-value caching, we compute the attention scores
483
+ # only for the new sequence. Thus, the matrix of scores is of size
484
+ # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
485
+ # j > cache_len + i, since row i corresponds to token cache_len + i.
486
+ mask = torch .hstack ([
487
+ torch .zeros ((seqlen , start_pos ), device = tokens .device ),
488
+ mask
489
+ ]).type_as (h )
480
490
481
491
for layer in self .layers :
482
492
h = layer (h , start_pos , freqs_cis , mask )
0 commit comments