@@ -193,6 +193,95 @@ def update(
        return k_out, v_out


+class SDPA(nn.Module):
+    def __init__(
+        self,
+        kv_cache: KVCache,
+        mask,
+        use_sdpa_with_kv_cache_op: bool,
+        dim: int,
+        n_rep: int,
+    ):
+        super().__init__()
+        self.kv_cache = kv_cache
+        self.mask = mask
+        self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
+        self.dim = dim
+        self.n_rep = n_rep
+
+    def forward(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+    ) -> torch.Tensor:
+        if not self.use_sdpa_with_kv_cache_op:
+            return self._forward_default(
+                input_pos,
+                q,
+                k,
+                v,
+                bsz,
+                seqlen,
+            )
+        else:
+            return self._forward_custom(
+                input_pos,
+                q,
+                k,
+                v,
+                bsz,
+                seqlen,
+            )
+
+    def _forward_custom(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+    ):
+        from .custom_ops import sdpa_with_kv_cache  # noqa
+
+        output = torch.ops.llama.sdpa_with_kv_cache(
+            q,
+            k,
+            v,
+            self.kv_cache.k_cache,
+            self.kv_cache.v_cache,
+            input_pos[-1].item(),
+            seqlen,
+        )
+        return output.view(bsz, seqlen, self.dim)
+
+    def _forward_default(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+    ) -> torch.Tensor:
+        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        k, v = self.kv_cache.update(input_pos, k, v)
+        mask = self.mask[None, None, input_pos]
+
+        k = k.repeat_interleave(self.n_rep, dim=1)
+        v = v.repeat_interleave(self.n_rep, dim=1)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+        return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+
+
class Attention(nn.Module):
    def __init__(self, args: ModelArgs, layer_id: int):
        super().__init__()
@@ -213,7 +302,6 @@ def __init__(self, args: ModelArgs, layer_id: int):
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

-        self.use_sdpa_with_kv_cache_op = args.use_sdpa_with_kv_cache_op
        self.layer_id = layer_id

        causal_mask = torch.tril(
@@ -234,6 +322,13 @@ def __init__(self, args: ModelArgs, layer_id: int):
            self.head_dim,
            not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op dont transpose the cache. Expect untransposed q k v
        )
+        self.SDPA = SDPA(
+            self.kv_cache,
+            self.mask,
+            args.use_sdpa_with_kv_cache_op,
+            self.dim,
+            self.n_rep,
+        )

    def forward(
        self,
@@ -256,41 +351,8 @@ def forward(

        if self.use_kv_cache:
            assert input_pos is not None
-
-            if not self.use_sdpa_with_kv_cache_op:
-
-                q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-                k = k.transpose(1, 2)
-                v = v.transpose(1, 2)
-
-                k, v = self.kv_cache.update(input_pos, k, v)
-                mask = self.mask[None, None, input_pos]
-
-                k = k.repeat_interleave(self.n_rep, dim=1)
-                v = v.repeat_interleave(self.n_rep, dim=1)
-                y = F.scaled_dot_product_attention(
-                    q, k, v, attn_mask=mask, dropout_p=0.0
-                )
-
-                y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
-
-                y = self.wo(y)
-                return y
-            else:
-                from .custom_ops.sdpa_with_kv_cache import sdpa_with_kv_cache  # noqa
-
-                output = torch.ops.llama.sdpa_with_kv_cache(
-                    q,
-                    k,
-                    v,
-                    self.kv_cache.k_cache,
-                    self.kv_cache.v_cache,
-                    input_pos[-1].item(),
-                    seqlen,
-                )
-                output = output.view(bsz, seqlen, -1)
-                output = self.wo(output)
-                return output
+            output = self.SDPA(input_pos, q, k, v, bsz, seqlen)
+            return self.wo(output)

        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        k = k.transpose(1, 2)
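
For readers skimming the diff: the default (non-custom-op) branch that now lives in SDPA._forward_default is plain grouped-query attention, where the smaller set of KV heads is expanded with repeat_interleave to match the query head count and F.scaled_dot_product_attention is run against a slice of the precomputed causal mask. Below is a minimal, self-contained sketch of that pattern with arbitrary toy shapes; none of the names or sizes are taken from the repository.

# Standalone sketch (not part of the diff): toy shapes chosen arbitrarily to
# illustrate the default path -- repeat_interleave for grouped-query attention
# plus F.scaled_dot_product_attention with a sliced causal mask.
import torch
import torch.nn.functional as F

bsz, seqlen, max_seq_len = 1, 4, 16
n_heads, n_kv_heads, head_dim = 8, 2, 64
n_rep = n_heads // n_kv_heads

# Queries use all heads; keys/values use the smaller KV head count.
q = torch.randn(bsz, n_heads, seqlen, head_dim)
k = torch.randn(bsz, n_kv_heads, seqlen, head_dim)
v = torch.randn(bsz, n_kv_heads, seqlen, head_dim)

# Causal mask precomputed once, then indexed by the current positions,
# mirroring self.mask[None, None, input_pos] in the module above.
causal_mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))
input_pos = torch.arange(seqlen)
mask = causal_mask[None, None, input_pos]  # (1, 1, seqlen, max_seq_len)
mask = mask[..., :seqlen]                  # no cache here, so keep only seqlen columns

# Expand KV heads so every query head has a matching key/value head.
k = k.repeat_interleave(n_rep, dim=1)
v = v.repeat_interleave(n_rep, dim=1)

y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
print(y.shape)  # torch.Size([1, 8, 4, 64])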