@@ -230,14 +230,13 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner {
      }
      int kv_group_update = 1;
      for (int h = 0; h < num_heads_q; h++) {
-       cutlass::DeviceAllocation<ElementOutput> block_S;
+       cutlass::DeviceAllocation<ElementAccumulator> block_S;
        block_S.reset(seq_len_qo * seq_len_kv);

        cutlass::TensorRef ref_Q(block_Q_.get() + offset_q, LayoutQ::packed({seq_len_qo, head_size_qk}));
        cutlass::TensorRef ref_K(block_K_.get() + offset_k, LayoutK::packed({head_size_qk, seq_len_kv}));
        cutlass::TensorRef ref_V(block_V_.get() + offset_v, LayoutV::packed({seq_len_kv, head_size_vo}));
        cutlass::TensorRef ref_S(block_S.get(), LayoutQ::packed({seq_len_qo, seq_len_kv}));
-       cutlass::TensorRef ref_O(block_ref_O.get() + offset_o, LayoutO::packed({seq_len_qo, head_size_vo}));

        cutlass::reference::device::GemmComplex({seq_len_qo, seq_len_kv, head_size_qk}, 1.f, ref_Q,
                                                cutlass::ComplexTransform::kNone, ref_K, cutlass::ComplexTransform::kNone,
@@ -251,9 +250,8 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner {
        syclcompat::wait();

-       std::vector<ElementOutput> host_S(block_S.size());
-       syclcompat::memcpy<ElementOutput>(host_S.data(), block_S.get(), host_S.size());
-       syclcompat::wait();
+       std::vector<ElementAccumulator> host_S(block_S.size());
+       syclcompat::memcpy<ElementAccumulator>(host_S.data(), block_S.get(), host_S.size());

        // delete this memory as it is no longer needed
        block_S.reset();
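Aside: the reason `block_S` and `host_S` move from `ElementOutput` to `ElementAccumulator` is that the intermediate S = Q·Kᵀ scores should stay in the wider accumulator type (typically float) rather than being staged through the narrower output type (typically a 16-bit float). A standalone sketch of what a 16-bit detour can cost; `through_half` is a hypothetical stand-in that quantizes to fp16's 11 significant bits:

```cpp
#include <cmath>
#include <cstdio>

// Hypothetical stand-in: quantize a float to fp16's 11 significant bits
// (a sketch; ignores overflow and subnormals) to emulate staging through half.
float through_half(float x) {
  int e;
  float m = std::frexp(x, &e);          // x = m * 2^e with 0.5 <= |m| < 1
  m = std::round(m * 2048.f) / 2048.f;  // keep 11 significant bits
  return std::ldexp(m, e);
}

int main() {
  // Attention scores routinely leave fp16's fine-grained range; near 1000
  // the fp16 spacing is already 0.5, so staging S in the output type rounds
  // away detail that the fp32 accumulator type preserves.
  float s = 1000.3f;
  std::printf("fp32-staged: %.4f\nfp16-staged: %.4f\n", s, through_half(s));
  return 0;
}
```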
@@ -265,13 +263,13 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner {
        for (int row = 0; row < seq_len_qo; row++) {
          for (int col = 0; col < seq_len_kv; col++) {
            if ((col - full_tile_offset) > (row - discard_seq_coord))
-             host_S[col + row * seq_len_kv] = -INFINITY;
+             host_S[col + row * seq_len_kv] = ElementAccumulator{-INFINITY};
          }
        }
      }

      // compute max element per row of S
-     std::vector<ElementOutput> max_vec(seq_len_qo, -INFINITY);
+     std::vector<ElementAccumulator> max_vec(seq_len_qo, ElementAccumulator{-INFINITY});
      for (int row = 0; row < seq_len_qo; row++) {
        int idx = row * seq_len_kv;
        int max_idx = row;
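For reference, the masking loop above discards everything above the causal diagonal by writing -INFINITY, which the later exp() maps to 0. A self-contained sketch of the same predicate; the choice `full_tile_offset = seq_len_kv - seq_len_qo` is an assumption made here to show a bottom-right-aligned diagonal, not taken from the example:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // Illustrative sizes; the example derives these from the problem shape.
  const int seq_len_qo = 4, seq_len_kv = 6;
  const int full_tile_offset = seq_len_kv - seq_len_qo;  // assumed alignment
  const int discard_seq_coord = 0;                       // assumed

  std::vector<float> S(seq_len_qo * seq_len_kv, 1.0f);
  for (int row = 0; row < seq_len_qo; row++)
    for (int col = 0; col < seq_len_kv; col++)
      if ((col - full_tile_offset) > (row - discard_seq_coord))
        S[col + row * seq_len_kv] = -INFINITY;  // masked; exp() turns it into 0

  // Print the mask: '1' = kept score, '.' = masked.
  for (int row = 0; row < seq_len_qo; row++, std::puts(""))
    for (int col = 0; col < seq_len_kv; col++)
      std::printf(" %c", std::isinf(S[col + row * seq_len_kv]) ? '.' : '1');
  return 0;
}
```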
@@ -287,12 +285,12 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner {
        int idx = row * seq_len_kv;
        int max_idx = row;
        for (int col = 0; col < seq_len_kv; col++, idx++) {
-         host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) / sqrt(static_cast<ElementOutput>((head_size_qk))));
+         host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) / sqrt(static_cast<ElementAccumulator>((head_size_qk))));
        }
      }

      // compute sum per row of S
-     std::vector<ElementOutput> sum_vec(seq_len_qo, ElementOutput{0});
+     std::vector<ElementAccumulator> sum_vec(seq_len_qo, ElementAccumulator{0});
      for (int row = 0; row < seq_len_qo; row++) {
        int idx = row * seq_len_kv;
        int sum_idx = row;
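The two hunks above are the max-subtraction and normalization halves of a numerically stable softmax, with the attention scale 1/sqrt(head_size_qk) applied inside the exponent; keeping them in ElementAccumulator means exp() always sees fp32 inputs. A minimal single-row sketch, with float standing in for ElementAccumulator:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const int head_size_qk = 64;
  std::vector<float> row = {1.f, 2.f, 3.f, 4.f};  // one row of S

  // 1) Row max, so exp() never sees a large positive argument.
  float row_max = -INFINITY;
  for (float s : row) row_max = std::max(row_max, s);

  // 2) Exponentiate the shifted, scaled scores and accumulate the row sum.
  const float scale = std::sqrt(static_cast<float>(head_size_qk));
  float sum = 0.f;
  for (float &s : row) {
    s = std::exp((s - row_max) / scale);
    sum += s;
  }

  // 3) Normalize so the row sums to 1.
  for (float &s : row) s /= sum;
  for (float s : row) std::printf("%.4f ", s);
  std::puts("");
  return 0;
}
```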
@@ -320,13 +318,16 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner {
        block_P.reset(host_P.size());

        syclcompat::memcpy<ElementV_>(block_P.get(), host_P.data(), host_P.size());
-       syclcompat::wait();

        cutlass::TensorRef ref_P(block_P.get(), LayoutQ::packed({seq_len_qo, seq_len_kv}));

-       cutlass::reference::device::GemmComplex({seq_len_qo, head_size_vo, seq_len_kv}, 1.f, ref_P,
+       cutlass::DeviceAllocation<ElementAccumulator> block_acc;
+       block_acc.reset(seq_len_qo * head_size_vo);
+       cutlass::TensorRef ref_acc(block_acc.get(), LayoutO::packed({seq_len_qo, head_size_vo}));
+
+       cutlass::reference::device::GemmComplex({seq_len_qo, head_size_vo, seq_len_kv}, ElementAccumulator{1}, ref_P,
                                                cutlass::ComplexTransform::kNone, ref_V, cutlass::ComplexTransform::kNone,
-                                               0.f, ref_O, ref_O, ElementAccumulator(0),
+                                               ElementAccumulator{0}, ref_acc, ref_acc, ElementAccumulator{0},
                                                1, // batch_count
                                                seq_len_qo * seq_len_kv, // batch_stride_P
                                                seq_len_kv * head_size_vo, // batch_stride_V
@@ -338,6 +339,17 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner {
        // delete this memory as it is no longer needed
        block_P.reset();

+       std::vector<ElementAccumulator> vec_acc(block_acc.size());
+       syclcompat::memcpy<ElementAccumulator>(vec_acc.data(), block_acc.get(), vec_acc.size());
+
+       // delete this memory as it is no longer needed
+       block_acc.reset();
+       std::vector<ElementOutput> vec_out(vec_acc.size());
+       for (int i = 0; i < vec_out.size(); i++) {
+         vec_out[i] = static_cast<ElementOutput>(vec_acc[i]);
+       }
+       syclcompat::memcpy<ElementOutput>(block_ref_O.get() + offset_o, vec_out.data(), vec_out.size());
+
        offset_q += seq_len_qo * head_size_qk;
        if (kv_group_update % q_group_size == 0) {
          offset_k += seq_len_kv * head_size_qk;
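The new `block_acc` / `vec_acc` / `vec_out` plumbing follows the accumulate-then-convert pattern: the P·V GEMM runs entirely in the accumulator type, and each output element is narrowed exactly once at the end instead of accumulating directly into the half-precision `block_ref_O`. A generic sketch of that pattern (the names here are mine, not the example's):

```cpp
#include <cstdio>
#include <vector>

// Accumulate-then-convert: do the whole reduction in the accumulator type,
// then perform a single narrowing cast per output element, mirroring
// block_acc -> vec_out -> block_ref_O above.
template <class Out, class Acc>
std::vector<Out> matmul_then_cast(const std::vector<Acc> &P,  // M x K
                                  const std::vector<Acc> &V,  // K x N
                                  int M, int N, int K) {
  std::vector<Acc> acc(static_cast<size_t>(M) * N, Acc{0});
  for (int m = 0; m < M; m++)
    for (int k = 0; k < K; k++)
      for (int n = 0; n < N; n++)
        acc[m * N + n] += P[m * K + k] * V[k * N + n];  // all math in Acc

  std::vector<Out> out(acc.size());
  for (size_t i = 0; i < acc.size(); i++)
    out[i] = static_cast<Out>(acc[i]);  // the one narrowing cast
  return out;
}

int main() {
  const int M = 2, N = 2, K = 3;
  std::vector<float> P = {1, 2, 3, 4, 5, 6};  // 2x3
  std::vector<float> V = {1, 0, 0, 1, 1, 1};  // 3x2
  auto O = matmul_then_cast<float, float>(P, V, M, N, K);
  for (int m = 0; m < M; m++, std::puts(""))
    for (int n = 0; n < N; n++) std::printf("%6.1f", O[m * N + n]);
  return 0;
}
```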
@@ -352,7 +364,7 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner {

      // Check if output from CUTLASS kernel and reference kernel are equal or not
      bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_O.get(), block_O.get(),
-                                                                           block_O.size(), 0.5f, 0.5f);
+                                                                           block_O.size(), ElementOutput{0.5}, ElementOutput{0.5});

      return passed;
    }
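The tolerances passed to BlockCompareRelativelyEqual stay at 0.5 but are now given in ElementOutput. A host-side sketch of the kind of elementwise relative check such a routine performs; the exact CUTLASS formula may differ, so treat the rule below as an assumption:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Assumed semantics: a and b are "relatively equal" when
// |a - b| <= epsilon * max(|a|, |b|), with nonzero_floor skipping the
// check when both values are effectively zero.
bool block_compare_relatively_equal(const std::vector<float> &a,
                                    const std::vector<float> &b,
                                    float epsilon, float nonzero_floor) {
  for (size_t i = 0; i < a.size(); i++) {
    float mag = std::max(std::fabs(a[i]), std::fabs(b[i]));
    if (mag < nonzero_floor) continue;  // both effectively zero
    if (std::fabs(a[i] - b[i]) > epsilon * mag) return false;
  }
  return true;
}

int main() {
  std::vector<float> ref = {1.00f, 2.00f, 3.00f};
  std::vector<float> out = {1.01f, 1.99f, 3.02f};
  std::printf("passed: %s\n",
              block_compare_relatively_equal(ref, out, 0.5f, 0.5f) ? "true"
                                                                   : "false");
  return 0;
}
```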
@@ -619,7 +631,7 @@ template <bool Causal,
    using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;
    using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16<PipelineStages>;
    using CollectiveEpilogue = cutlass::flash_attention::collective::FlashPrefillEpilogue<
-       EpilogueDispatchPolicy, MMAOperation, TileShapeOutput, SubgroupLayout, ElementAccumulator, cutlass::gemm::TagToStrideC_t<LayoutO>, ElementOutput,
+       EpilogueDispatchPolicy, MMAOperation, TileShapeOutput, SubgroupLayout, ElementComputeEpilogue, ElementOutput, cutlass::gemm::TagToStrideC_t<LayoutO>, ElementOutput,
        GmemTiledCopyStore>;
    using CollectiveSoftmaxEpilogue = cutlass::flash_attention::collective::FlashPrefillSoftmaxEpilogue<Causal, EpilogueDispatchPolicy, ElementAccumulator>;