Commit d5f1886

Separate output and accumulator type for Flash Attention Prefill (#443)
This PR separates the output type and accumulator type for Flash Attention Prefill. Combinations supported are:

* bf16 inputs, fp32 accumulator, bf16 | fp32 output
* fp16 inputs, fp32 accumulator, fp16 | fp32 output
* fp8 inputs, fp32 accumulator, fp8 | fp32 output

Tests added in: #446
Benchmarks added in: #447

Co-authored-by: Alejandro Acosta <[email protected]>
1 parent c316fb5 commit d5f1886
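
To make the new template contract concrete, here is a minimal, illustrative sketch of wiring up the bf16-input / fp32-accumulator / bf16-output combination. MMAOperation, TileShapeOutput, SubgroupLayout, LayoutO and GmemTiledCopyStore are placeholders for aliases that the surrounding kernel configuration already defines (see the configuration and example diffs below); only the argument order of FlashPrefillEpilogue reflects this change.

// Element types for the bf16 inputs / fp32 accumulator / bf16 output combination.
using ElementInputType       = cutlass::bfloat16_t;  // Q, K, V
using ElementAccumulator     = float;                // accumulation / softmax type
using ElementComputeEpilogue = float;                // epilogue compute type
using ElementOutput          = cutlass::bfloat16_t;  // O (fp32 output is also supported)

// The epilogue now takes the compute/accumulator type and the output type separately.
using CollectiveEpilogue = cutlass::flash_attention::collective::FlashPrefillEpilogue<
    cutlass::epilogue::IntelXeXMX16,         // EpilogueDispatchPolicy
    MMAOperation, TileShapeOutput, SubgroupLayout,
    ElementComputeEpilogue,                  // ElementCompute_ (new parameter)
    ElementOutput,                           // ElementO_
    cutlass::gemm::TagToStrideC_t<LayoutO>,  // StrideO_
    ElementOutput,                           // ElementLSE_
    GmemTiledCopyStore>;                     // CopyOpO_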

7 files changed: +101 additions, −52 deletions

applications/flash_attention_v2/collective/xe_flash_attn_prefill_epilogue.hpp

Lines changed: 16 additions & 5 deletions
@@ -53,15 +53,14 @@
   static_assert(cutlass::detail::dependent_false<DispatchPolicy>, "Could not find an epilogue specialization.");
 };
 
-template <class MMAOperation_, class TileShapeOutput_, class SubgroupLayout_, class ElementO_, class StrideO_, class ElementLSE_, class CopyOpO_>
-class FlashPrefillEpilogue<epilogue::IntelXeXMX16, MMAOperation_, TileShapeOutput_, SubgroupLayout_, ElementO_, StrideO_, ElementLSE_, CopyOpO_> {
+template <class MMAOperation_, class TileShapeOutput_, class SubgroupLayout_, class ElementCompute_, class ElementO_, class StrideO_, class ElementLSE_, class CopyOpO_>
+class FlashPrefillEpilogue<epilogue::IntelXeXMX16, MMAOperation_, TileShapeOutput_, SubgroupLayout_, ElementCompute_, ElementO_, StrideO_, ElementLSE_, CopyOpO_> {
 public:
   //
   // Type Aliases
   //
   using DispatchPolicy = epilogue::IntelXeXMX16;
   using ElementO = ElementO_;
-  using ElementAccumulator = ElementO_;
   using StrideO = StrideO_;
   using ElementLSE = ElementLSE_;
   using CopyOpO = CopyOpO_;
@@ -70,7 +69,8 @@
   using TiledMmaOutput = typename TiledMMAHelper<MMA_Atom<MMAOperation_>, Layout<TileShapeOutput>, SubgroupLayout>::TiledMMA;
   using GmemTiledCopyO = CopyOpO;
   using ElementOutput = ElementO_;
-  using ElementCompute = ElementO_;
+  using ElementCompute = ElementCompute_;
+  using ElementAccumulator = ElementCompute_;
   using SubgroupTileShape = decltype(cute::shape_div(TileShapeOutput{}, (SubgroupLayout{}.shape())));
 
   static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;
@@ -196,7 +196,18 @@
     auto thread_xe_store_o = params.xe_store_o.get_thread_slice(ThreadIdxX());
     Tensor tOgO = thread_xe_store_o.partition_D(gO);
 
-    copy(params.xe_store_o, out_reg, tOgO);
+    Tensor final_out_reg = make_fragment_like<ElementOutput>(out_reg);
+    // iff ElementOutput == ElementAccumulator, then convert_type doesn't do the right conversion
+    // iff ElementOutput == fp8, there is no NumericConverter specialization available
+    // for both the above cases, we call copy() which internally performs a static_cast op on the data.
+    // for ElementOutput == bf16 | fp16, convert_type calls relevant NumericConverter specialization.
+    if constexpr (cute::is_any_of_v<ElementOutput, cute::float_e5m2_t, cute::float_e4m3_t> || cute::is_same_v<ElementOutput, ElementCompute>) {
+      copy(out_reg, final_out_reg);
+    } else {
+      Tensor temp = convert_type<ElementOutput>(out_reg);
+      copy(temp, final_out_reg);
+    }
+    copy(params.xe_store_o, final_out_reg, tOgO);
   }
 
   // SequenceLengthShapeType = Shape<int, int>
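
In the non-fp8, non-identity branch above, convert_type() ultimately performs an element-wise numeric conversion of the fp32 register fragment into the output element type before the block store. A hedged, host-side sketch of that per-element step using cutlass::NumericArrayConverter directly; the fragment is modeled here as a cutlass::Array, and a name like convert_fragment is purely illustrative (the epilogue itself operates on the CuTe register tensor):

#include <cutlass/array.h>
#include <cutlass/bfloat16.h>
#include <cutlass/numeric_conversion.h>

// Convert an fp32 accumulator fragment to bf16 for the final store,
// mirroring the convert_type() path in the epilogue above (illustrative only).
template <int N>
cutlass::Array<cutlass::bfloat16_t, N>
convert_fragment(cutlass::Array<float, N> const& acc) {
  // NumericArrayConverter applies NumericConverter element-wise over the fragment.
  cutlass::NumericArrayConverter<cutlass::bfloat16_t, float, N> converter;
  return converter(acc);
}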

applications/flash_attention_v2/kernel/xe_flash_attn_prefill.hpp

Lines changed: 1 addition & 1 deletion
@@ -370,7 +370,7 @@
           CUTLASS_PRAGMA_UNROLL
           for (int row = 0; row < Vec; row++, row_idx++) { // 8
             if (col_idx - full_tile_offset > row_idx - discard_seq_coord) {
-              tSr(row, m, n) = -INFINITY;
+              tSr(row, m, n) = ElementAccumulator{-INFINITY};
             }
           }
         }

benchmarks/flash_attention/flash_attention_prefill/benchmark_runner.hpp

Lines changed: 27 additions & 15 deletions
@@ -189,18 +189,17 @@
       }
       int kv_group_update=1;
       for (int h = 0; h < num_heads_q; h++) {
-        cutlass::DeviceAllocation<ElementOutput> block_S;
+        cutlass::DeviceAllocation<ElementAccumulator> block_S;
         block_S.reset(seq_len_qo * seq_len_kv);
 
         cutlass::TensorRef ref_Q(block_Q[0].get() + offset_q, LayoutQ::packed({seq_len_qo, head_size_qk}));
         cutlass::TensorRef ref_K(block_K[0].get() + offset_k, LayoutK::packed({head_size_qk, seq_len_kv}));
         cutlass::TensorRef ref_V(block_V[0].get() + offset_v, LayoutV::packed({seq_len_kv, head_size_vo}));
         cutlass::TensorRef ref_S(block_S.get(), LayoutQ::packed({seq_len_qo, seq_len_kv}));
-        cutlass::TensorRef ref_O(block_ref_O.get() + offset_o, LayoutO::packed({seq_len_qo, head_size_vo}));
 
-        cutlass::reference::device::GemmComplex({seq_len_qo, seq_len_kv, head_size_qk}, 1.f, ref_Q,
+        cutlass::reference::device::GemmComplex({seq_len_qo, seq_len_kv, head_size_qk}, ElementAccumulator{1.f}, ref_Q,
                                                 cutlass::ComplexTransform::kNone, ref_K, cutlass::ComplexTransform::kNone,
-                                                0.f, ref_S, ref_S, ElementAccumulator(0),
+                                                ElementAccumulator{0}, ref_S, ref_S, ElementAccumulator{0},
                                                 1,                         // batch_count
                                                 seq_len_qo * head_size_qk, // batch_stride_Q
                                                 seq_len_kv * head_size_qk, // batch_stride_K
@@ -210,9 +209,8 @@
 
         syclcompat::wait();
 
-        std::vector<ElementOutput> host_S(block_S.size());
-        syclcompat::memcpy<ElementOutput>(host_S.data(), block_S.get(), host_S.size());
-        syclcompat::wait();
+        std::vector<ElementAccumulator> host_S(block_S.size());
+        syclcompat::memcpy<ElementAccumulator>(host_S.data(), block_S.get(), host_S.size());
 
         // delete this memory as it is no longer needed
         block_S.reset();
@@ -224,13 +222,13 @@
           for (int row = 0; row < seq_len_qo; row++) {
             for (int col = 0; col < seq_len_kv; col++) {
               if ((col - full_tile_offset) > (row - discard_seq_coord))
-                host_S[col + row * seq_len_kv] = -INFINITY;
+                host_S[col + row * seq_len_kv] = ElementAccumulator{-INFINITY};
             }
           }
         }
 
         // compute max element per row of S
-        std::vector<ElementOutput> max_vec(seq_len_qo, -INFINITY);
+        std::vector<ElementAccumulator> max_vec(seq_len_qo, ElementAccumulator{-INFINITY});
         for (int row = 0; row < seq_len_qo; row++) {
           int idx = row * seq_len_kv;
           int max_idx = row;
@@ -246,12 +244,12 @@
           int idx = row * seq_len_kv;
           int max_idx = row;
           for (int col = 0; col < seq_len_kv; col++, idx++) {
-            host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) / std::sqrt(static_cast<ElementOutput>((head_size_qk))));
+            host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) / std::sqrt(static_cast<ElementAccumulator>((head_size_qk))));
           }
         }
 
         // compute sum per row of S
-        std::vector<ElementOutput> sum_vec(seq_len_qo, ElementOutput{0});
+        std::vector<ElementAccumulator> sum_vec(seq_len_qo, ElementAccumulator{0});
         for (int row = 0; row < seq_len_qo; row++) {
           int idx = row * seq_len_kv;
           int sum_idx = row;
@@ -279,13 +277,16 @@
         block_P.reset(host_P.size());
 
         syclcompat::memcpy<ElementV>(block_P.get(), host_P.data(), host_P.size());
-        syclcompat::wait();
 
         cutlass::TensorRef ref_P(block_P.get(), LayoutQ::packed({seq_len_qo, seq_len_kv}));
 
-        cutlass::reference::device::GemmComplex({seq_len_qo, head_size_vo, seq_len_kv}, 1.f, ref_P,
+        cutlass::DeviceAllocation<ElementAccumulator> block_acc;
+        block_acc.reset(seq_len_qo * head_size_vo);
+        cutlass::TensorRef ref_acc(block_acc.get(), LayoutO::packed({seq_len_qo, head_size_vo}));
+
+        cutlass::reference::device::GemmComplex({seq_len_qo, head_size_vo, seq_len_kv}, ElementAccumulator{1}, ref_P,
                                                 cutlass::ComplexTransform::kNone, ref_V, cutlass::ComplexTransform::kNone,
-                                                0.f, ref_O, ref_O, ElementAccumulator(0),
+                                                ElementAccumulator{0}, ref_acc, ref_acc, ElementAccumulator{0},
                                                 1,                         // batch_count
                                                 seq_len_qo * seq_len_kv,   // batch_stride_P
                                                 seq_len_kv * head_size_vo, // batch_stride_V
@@ -297,6 +298,17 @@
         // delete this memory as it is no longer needed
         block_P.reset();
 
+        std::vector<ElementAccumulator> vec_acc(block_acc.size());
+        syclcompat::memcpy<ElementAccumulator>(vec_acc.data(), block_acc.get(), vec_acc.size());
+
+        // delete this memory as it is no longer needed
+        block_acc.reset();
+        std::vector<ElementOutput> vec_out(vec_acc.size());
+        for(int i = 0; i < vec_out.size(); i++) {
+          vec_out[i] = static_cast<ElementOutput>(vec_acc[i]);
+        }
+        syclcompat::memcpy<ElementOutput>(block_ref_O.get() + offset_o, vec_out.data(), vec_out.size());
+
         offset_q += seq_len_qo * head_size_qk;
         if(kv_group_update % q_group_size==0) {
           offset_k += seq_len_kv * head_size_qk;
@@ -311,7 +323,7 @@
 
       // Check if output from CUTLASS kernel and reference kernel are equal or not
       bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_O.get(), block_O.get(),
-                                                                            block_O.size(), 0.5f, 0.5f);
+                                                                            block_O.size(), ElementOutput{0.5}, ElementOutput{0.5});
 
       return passed;
     }
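
The reference check now mirrors the kernel's typing: the P*V product is accumulated into an ElementAccumulator buffer (block_acc) and only cast element-wise to ElementOutput right before it is written to block_ref_O. A small standalone illustration of why converting once at the end differs from accumulating directly in a narrow output type; the to_bf16 helper below is a crude truncation stand-in for illustration only, not the library's rounding conversion:

#include <cstdint>
#include <cstdio>
#include <cstring>

// Crude bf16 stand-in: keep sign, exponent and the top 7 mantissa bits (truncation).
static float to_bf16(float x) {
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= 0xFFFF0000u;
  float y;
  std::memcpy(&y, &bits, sizeof(y));
  return y;
}

int main() {
  float acc_fp32 = 0.f;   // accumulate in fp32, convert once at the end
  float acc_bf16 = 0.f;   // accumulate directly in the narrow output type
  for (int i = 0; i < 4096; ++i) {
    acc_fp32 += 0.001f;
    acc_bf16 = to_bf16(acc_bf16 + 0.001f);
  }
  std::printf("fp32 accumulate, bf16 store: %f\n", to_bf16(acc_fp32)); // ~4.09
  std::printf("bf16 accumulate:             %f\n", acc_bf16);          // stalls well below that
}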

benchmarks/flash_attention/flash_attention_prefill/fmha_prefill_configuration.hpp

Lines changed: 1 addition & 1 deletion
@@ -67,7 +67,7 @@
   using MMAOperation = typename MMAOP<GEMMDispatchPolicy, ElementInputType,ElementAccumulator>::Type;
   using CollectiveEpilogue = cutlass::flash_attention::collective::FlashPrefillEpilogue<
       EpilogueDispatchPolicy, MMAOperation, TileShapeOutput,
-      SubgroupLayout, ElementAccumulator,
+      SubgroupLayout, ElementAccumulator, ElementOutputType,
       cutlass::gemm::TagToStrideC_t<LayoutO>, ElementOutput,
       GmemTiledCopyO>;
 
examples/sycl/06_bmg_flash_attention/bmg_flash_attn_prefill_runner.hpp

Lines changed: 26 additions & 14 deletions
@@ -230,14 +230,13 @@
       }
       int kv_group_update=1;
       for (int h = 0; h < num_heads_q; h++) {
-        cutlass::DeviceAllocation<ElementOutput> block_S;
+        cutlass::DeviceAllocation<ElementAccumulator> block_S;
         block_S.reset(seq_len_qo * seq_len_kv);
 
         cutlass::TensorRef ref_Q(block_Q_.get() + offset_q, LayoutQ::packed({seq_len_qo, head_size_qk}));
         cutlass::TensorRef ref_K(block_K_.get() + offset_k, LayoutK::packed({head_size_qk, seq_len_kv}));
         cutlass::TensorRef ref_V(block_V_.get() + offset_v, LayoutV::packed({seq_len_kv, head_size_vo}));
         cutlass::TensorRef ref_S(block_S.get(), LayoutQ::packed({seq_len_qo, seq_len_kv}));
-        cutlass::TensorRef ref_O(block_ref_O.get() + offset_o, LayoutO::packed({seq_len_qo, head_size_vo}));
 
         cutlass::reference::device::GemmComplex({seq_len_qo, seq_len_kv, head_size_qk}, 1.f, ref_Q,
                                                 cutlass::ComplexTransform::kNone, ref_K, cutlass::ComplexTransform::kNone,
@@ -251,9 +250,8 @@
 
         syclcompat::wait();
 
-        std::vector<ElementOutput> host_S(block_S.size());
-        syclcompat::memcpy<ElementOutput>(host_S.data(), block_S.get(), host_S.size());
-        syclcompat::wait();
+        std::vector<ElementAccumulator> host_S(block_S.size());
+        syclcompat::memcpy<ElementAccumulator>(host_S.data(), block_S.get(), host_S.size());
 
         // delete this memory as it is no longer needed
         block_S.reset();
@@ -265,13 +263,13 @@
           for (int row = 0; row < seq_len_qo; row++) {
             for (int col = 0; col < seq_len_kv; col++) {
               if ((col - full_tile_offset) > (row - discard_seq_coord))
-                host_S[col + row * seq_len_kv] = -INFINITY;
+                host_S[col + row * seq_len_kv] = ElementAccumulator{-INFINITY};
             }
           }
         }
 
         // compute max element per row of S
-        std::vector<ElementOutput> max_vec(seq_len_qo, -INFINITY);
+        std::vector<ElementAccumulator> max_vec(seq_len_qo, ElementAccumulator{-INFINITY});
         for (int row = 0; row < seq_len_qo; row++) {
           int idx = row * seq_len_kv;
           int max_idx = row;
@@ -287,12 +285,12 @@
           int idx = row * seq_len_kv;
           int max_idx = row;
           for (int col = 0; col < seq_len_kv; col++, idx++) {
-            host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) / sqrt(static_cast<ElementOutput>((head_size_qk))));
+            host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) / sqrt(static_cast<ElementAccumulator>((head_size_qk))));
           }
         }
 
         // compute sum per row of S
-        std::vector<ElementOutput> sum_vec(seq_len_qo, ElementOutput{0});
+        std::vector<ElementAccumulator> sum_vec(seq_len_qo, ElementAccumulator{0});
         for (int row = 0; row < seq_len_qo; row++) {
           int idx = row * seq_len_kv;
           int sum_idx = row;
@@ -320,13 +318,16 @@
         block_P.reset(host_P.size());
 
         syclcompat::memcpy<ElementV_>(block_P.get(), host_P.data(), host_P.size());
-        syclcompat::wait();
 
         cutlass::TensorRef ref_P(block_P.get(), LayoutQ::packed({seq_len_qo, seq_len_kv}));
 
-        cutlass::reference::device::GemmComplex({seq_len_qo, head_size_vo, seq_len_kv}, 1.f, ref_P,
+        cutlass::DeviceAllocation<ElementAccumulator> block_acc;
+        block_acc.reset(seq_len_qo * head_size_vo);
+        cutlass::TensorRef ref_acc(block_acc.get(), LayoutO::packed({seq_len_qo, head_size_vo}));
+
+        cutlass::reference::device::GemmComplex({seq_len_qo, head_size_vo, seq_len_kv}, ElementAccumulator{1}, ref_P,
                                                 cutlass::ComplexTransform::kNone, ref_V, cutlass::ComplexTransform::kNone,
-                                                0.f, ref_O, ref_O, ElementAccumulator(0),
+                                                ElementAccumulator{0}, ref_acc, ref_acc, ElementAccumulator{0},
                                                 1,                         // batch_count
                                                 seq_len_qo * seq_len_kv,   // batch_stride_P
                                                 seq_len_kv * head_size_vo, // batch_stride_V
@@ -338,6 +339,17 @@
         // delete this memory as it is no longer needed
         block_P.reset();
 
+        std::vector<ElementAccumulator> vec_acc(block_acc.size());
+        syclcompat::memcpy<ElementAccumulator>(vec_acc.data(), block_acc.get(), vec_acc.size());
+
+        // delete this memory as it is no longer needed
+        block_acc.reset();
+        std::vector<ElementOutput> vec_out(vec_acc.size());
+        for(int i = 0; i < vec_out.size(); i++) {
+          vec_out[i] = static_cast<ElementOutput>(vec_acc[i]);
+        }
+        syclcompat::memcpy<ElementOutput>(block_ref_O.get() + offset_o, vec_out.data(), vec_out.size());
+
         offset_q += seq_len_qo * head_size_qk;
         if(kv_group_update % q_group_size==0) {
           offset_k += seq_len_kv * head_size_qk;
@@ -352,7 +364,7 @@
 
       // Check if output from CUTLASS kernel and reference kernel are equal or not
      bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_O.get(), block_O.get(),
-                                                                            block_O.size(), 0.5f, 0.5f);
+                                                                            block_O.size(), ElementOutput{0.5}, ElementOutput{0.5});
 
       return passed;
     }
@@ -619,7 +631,7 @@
   using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;
   using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16<PipelineStages>;
   using CollectiveEpilogue = cutlass::flash_attention::collective::FlashPrefillEpilogue<
-      EpilogueDispatchPolicy, MMAOperation, TileShapeOutput, SubgroupLayout, ElementAccumulator, cutlass::gemm::TagToStrideC_t<LayoutO>, ElementOutput,
+      EpilogueDispatchPolicy, MMAOperation, TileShapeOutput, SubgroupLayout, ElementComputeEpilogue, ElementOutput, cutlass::gemm::TagToStrideC_t<LayoutO>, ElementOutput,
       GmemTiledCopyStore>;
   using CollectiveSoftmaxEpilogue = cutlass::flash_attention::collective::FlashPrefillSoftmaxEpilogue<Causal, EpilogueDispatchPolicy, ElementAccumulator>;
 
include/cute/arch/copy_xe_U8.hpp

Lines changed: 1 addition & 0 deletions
@@ -682,6 +682,7 @@
 };
 
 struct XE_2D_U8x8x16_ST_N {
+  using BlockShape = Shape<_8, _16>;
   template <class T>
   CUTE_HOST_DEVICE static void copy(void *baseoffset, int width, int height,
                                     int pitch, intel::coord_t coord,
