Refactor tests for Flash Attention Prefill #446

Conversation

muhammad-tanvir-1211 (Collaborator) commented:

This PR adds tests for all the different data types supported by Flash Attention Prefill. It is a continuation of PR #443.

@joeatodd (Collaborator) left a comment:

LGTM. I am only reviewing the stuff that isn't in #443.

@@ -187,6 +252,27 @@ struct TestbedImpl {
// Methods
//

template <typename SrcT, typename DstT>
void convert_fp8_to_fp16(const SrcT* d_src, DstT* d_dst, size_t size) {
Collaborator:

Assuming this PR lands after #351, we could use a single definition of convert_fp8_to_fp16, which I suggested moving to sycl_common.hpp.

Collaborator:

I'll rename it to convert_dtype and move it to sycl_common.hpp. Thanks!
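
For reference, a rough sketch of what such a generic helper in sycl_common.hpp could look like is below. The queue parameter, placement, and body are assumptions for illustration, not the actual code in this PR.

    // Hypothetical sketch of a generic convert_dtype helper (assumed
    // signature and queue handling): an element-wise cast between two USM
    // device buffers, e.g. from an fp8 type to half_t for verification.
    #include <cstddef>
    #include <sycl/sycl.hpp>

    template <typename SrcT, typename DstT>
    void convert_dtype(sycl::queue& q, const SrcT* d_src, DstT* d_dst,
                       std::size_t size) {
      q.parallel_for(sycl::range<1>(size), [=](sycl::id<1> idx) {
         // Relies on SrcT being convertible to DstT, which holds for the
         // CUTLASS numeric types used by these tests.
         d_dst[idx] = static_cast<DstT>(d_src[idx]);
       }).wait();
    }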

Comment on lines +266 to +267
using outType = cute::conditional_t<is_fp8_v<Tin>, half_t, Tin>;
if constexpr(is_fp8_v<Tin>) {
Collaborator:

Suggested change:
-    using outType = cute::conditional_t<is_fp8_v<Tin>, half_t, Tin>;
-    if constexpr(is_fp8_v<Tin>) {
+    if constexpr(is_fp8_v<Tin>) {
+      using outType = half_t;
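
The point of the suggestion is that the half_t alias is only needed on the fp8 path, so it can be scoped inside the if constexpr branch instead of being selected up front with cute::conditional_t. A minimal, self-contained illustration of that scoping is below; the is_fp8_v stand-in and the function name are assumptions, not the test suite's actual code.

    #include <cstddef>
    #include <type_traits>
    #include <cutlass/numeric_types.h>  // half_t, float_e4m3_t, float_e5m2_t

    // Stand-in for the test suite's is_fp8_v trait (assumed definition).
    template <class T>
    inline constexpr bool is_fp8_v =
        std::is_same_v<T, cutlass::float_e4m3_t> ||
        std::is_same_v<T, cutlass::float_e5m2_t>;

    // The alias lives inside the branch that actually uses it.
    template <class Tin>
    constexpr std::size_t verification_type_size() {
      if constexpr (is_fp8_v<Tin>) {
        using outType = cutlass::half_t;  // fp8 inputs are verified via half_t
        return sizeof(outType);
      } else {
        using outType = Tin;              // other inputs keep their native type
        return sizeof(outType);
      }
    }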

muhammad-tanvir-1211 force-pushed the flash_prefill_separate_out_tests branch from be16975 to ab0e187 on June 27, 2025 at 13:43.
aacostadiaz added a commit that referenced this pull request on Jun 28, 2025:
This PR separates the output type and accumulator type for Flash
Attention Prefill. Combinations supported are:

* bf16 inputs, fp32 accumulator, bf16 | fp32 output
* fp16 inputs, fp32 accumulator, fp16 | fp32 output
* fp8 inputs, fp32 accumulator, fp8 | fp32 output

Tests added in: #446 
Benchmarks added in: #447

---------

Co-authored-by: Alejandro Acosta <[email protected]>
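
As a reference point, the (input, accumulator, output) combinations listed in the commit message above can be written down as a compile-time type list, as in the sketch below. The alias names and the std::tuple encoding are illustrative only and are not the test fixtures added in #446.

    #include <tuple>
    #include <cutlass/numeric_types.h>  // bfloat16_t, half_t, float_e4m3_t

    using bf16 = cutlass::bfloat16_t;
    using fp16 = cutlass::half_t;
    using fp8  = cutlass::float_e4m3_t;  // one of the two fp8 variants

    // Each entry is (ElementInput, ElementAccumulator, ElementOutput).
    using SupportedCombinations = std::tuple<
        std::tuple<bf16, float, bf16>, std::tuple<bf16, float, float>,
        std::tuple<fp16, float, fp16>, std::tuple<fp16, float, float>,
        std::tuple<fp8,  float, fp8>,  std::tuple<fp8,  float, float>>;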
aacostadiaz merged commit 725aab4 into codeplaysoftware:sycl-develop on Jun 30, 2025.
28 of 54 checks passed
4 participants