@@ -197,14 +197,14 @@ class SDPA(nn.Module):
     def __init__(
         self,
         kv_cache: KVCache,
-        mask,
         dim: int,
+        head_dim: int,
         n_rep: int,
     ):
         super().__init__()
         self.kv_cache = kv_cache
-        self.mask = mask
         self.dim = dim
+        self.head_dim = head_dim
         self.n_rep = n_rep

     def forward(
@@ -215,17 +215,18 @@ def forward(
         v: torch.Tensor,
         bsz,
         seqlen,
+        mask: torch.Tensor,
     ) -> torch.Tensor:
         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)

         k, v = self.kv_cache.update(input_pos, k, v)
-        mask = self.mask[None, None, input_pos]
+        attn_mask = mask[None, None, input_pos]

         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)

         return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

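The new forward() takes the mask as an explicit argument and slices it per decoded position before handing it to `F.scaled_dot_product_attention`; renaming the local to `attn_mask` also avoids shadowing the incoming `mask` parameter. The toy below is not part of the PR: all sizes and the boolean lower-triangular mask are made-up assumptions, but the `mask[None, None, input_pos]` slicing, the `repeat_interleave` expansion of grouped kv heads, and the SDPA call mirror the lines above, so the broadcasting is easy to check in isolation.

```python
# Toy check of the mask slicing and SDPA call from the diff; all sizes are made up.
import torch
import torch.nn.functional as F

bsz, n_heads, n_kv_heads, head_dim, max_seq_len = 1, 8, 2, 16, 32
n_rep = n_heads // n_kv_heads  # how many query heads share one kv head

# Assumed boolean causal mask: True = "may attend".
mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))
input_pos = torch.tensor([5])  # decoding a single token at position 5

q = torch.randn(bsz, n_heads, 1, head_dim)               # already (bs, heads, seq, head_dim)
k = torch.randn(bsz, n_kv_heads, max_seq_len, head_dim)  # full cache contents
v = torch.randn(bsz, n_kv_heads, max_seq_len, head_dim)

attn_mask = mask[None, None, input_pos]    # -> (1, 1, 1, max_seq_len)
k = k.repeat_interleave(n_rep, dim=1)      # -> (bsz, n_heads, max_seq_len, head_dim)
v = v.repeat_interleave(n_rep, dim=1)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
print(y.shape)  # torch.Size([1, 8, 1, 16])
```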
@@ -271,10 +272,10 @@ def __init__(self, args: ModelArgs, layer_id: int):
             not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op dont transpose the cache. Expect untransposed q k v
         )
         self.SDPA = SDPA(
-            self.kv_cache,
-            self.mask,
-            self.dim,
-            self.n_rep,
+            kv_cache=self.kv_cache,
+            dim=self.dim,
+            head_dim=self.head_dim,
+            n_rep=self.n_rep,
        )

     def forward(
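Since `__init__` drops `mask` and gains `head_dim`, the call site also switches to keyword arguments. The snippet below is a hypothetical illustration (the classes are stand-ins, not the real ones) of the pitfall this avoids: the old positional call would still run against the new signature, just with the values bound to the wrong parameters.

```python
# Stand-in classes sketching the before/after constructor signatures.
class OldSDPA:
    def __init__(self, kv_cache, mask, dim, n_rep): ...  # old signature, for comparison

class NewSDPA:
    def __init__(self, kv_cache, dim, head_dim, n_rep): ...

kv_cache, mask, dim, head_dim, n_rep = object(), object(), 2048, 64, 4

# Old positional call applied to the new signature: mask lands in `dim`,
# dim lands in `head_dim` -- it still constructs, just silently wrong.
bad = NewSDPA(kv_cache, mask, dim, n_rep)

# Keyword arguments make that kind of mismatch impossible.
good = NewSDPA(kv_cache=kv_cache, dim=dim, head_dim=head_dim, n_rep=n_rep)
```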
@@ -298,7 +299,7 @@ def forward(

         if self.use_kv_cache:
             assert input_pos is not None
-            output = self.SDPA(input_pos, q, k, v, bsz, seqlen)
+            output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
             return self.wo(output)

         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
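Tying the hunks together: SDPA no longer stores the mask, while Attention keeps owning `self.mask` and passes it into the submodule on every kv-cache call, which lines up with the `use_sdpa_with_kv_cache_op` path mentioned in the context lines above. The sketch below is an illustrative reconstruction rather than the PR's code: `_ToyKVCache` is a hypothetical stand-in for the real KVCache (only the `update()` call used above is mimicked), and the sizes and boolean causal mask are assumptions.

```python
# Illustrative reconstruction of the refactor (not the PR's code).
import torch
import torch.nn as nn
import torch.nn.functional as F


class _ToyKVCache(nn.Module):
    """Hypothetical stand-in for the real KVCache."""

    def __init__(self, bsz, n_kv_heads, max_seq_len, head_dim):
        super().__init__()
        self.register_buffer("k", torch.zeros(bsz, n_kv_heads, max_seq_len, head_dim))
        self.register_buffer("v", torch.zeros(bsz, n_kv_heads, max_seq_len, head_dim))

    def update(self, input_pos, k_val, v_val):
        # Write the new entries at input_pos and return the full cache.
        self.k[:, :, input_pos] = k_val
        self.v[:, :, input_pos] = v_val
        return self.k, self.v


class SDPA(nn.Module):
    # As refactored above: the mask arrives through forward(), not __init__().
    def __init__(self, kv_cache, dim, head_dim, n_rep):
        super().__init__()
        self.kv_cache = kv_cache
        self.dim = dim
        self.head_dim = head_dim
        self.n_rep = n_rep

    def forward(self, input_pos, q, k, v, bsz, seqlen, mask):
        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        k, v = self.kv_cache.update(input_pos, k, v)
        attn_mask = mask[None, None, input_pos]

        k = k.repeat_interleave(self.n_rep, dim=1)
        v = v.repeat_interleave(self.n_rep, dim=1)
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)

        return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)


bsz, seqlen, n_heads, n_kv_heads, head_dim, max_seq_len = 1, 1, 8, 2, 16, 32
dim, n_rep = n_heads * head_dim, n_heads // n_kv_heads

sdpa = SDPA(
    kv_cache=_ToyKVCache(bsz, n_kv_heads, max_seq_len, head_dim),
    dim=dim,
    head_dim=head_dim,
    n_rep=n_rep,
)

# The caller (Attention in the real model) keeps owning the mask and passes it in,
# as in the last hunk's call site.
mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))
input_pos = torch.tensor([0])
q = torch.randn(bsz, seqlen, n_heads, head_dim)
k = torch.randn(bsz, seqlen, n_kv_heads, head_dim)
v = torch.randn(bsz, seqlen, n_kv_heads, head_dim)

out = sdpa(input_pos, q, k, v, bsz, seqlen, mask)
print(out.shape)  # torch.Size([1, 1, 128])
```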