-
Notifications
You must be signed in to change notification settings - Fork 34
Refactor tests for Flash Attention Prefill #446
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor tests for Flash Attention Prefill #446
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. I am only reviewing the stuff that isn't in #443.
@@ -187,6 +252,27 @@ struct TestbedImpl { | |||
// Methods | |||
// | |||
|
|||
template <typename SrcT, typename DstT> | |||
void convert_fp8_to_fp16(const SrcT* d_src, DstT* d_dst, size_t size) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Assuming this PR lands after #351, we could use a single definition for convert_fp8_to_fp16
which I suggested to move to sycl_common.hpp
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll rename it as convert_dtype
& move it to sycl_common.hpp
. Thanks!
using outType = cute::conditional_t<is_fp8_v<Tin>, half_t, Tin>; | ||
if constexpr(is_fp8_v<Tin>) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
using outType = cute::conditional_t<is_fp8_v<Tin>, half_t, Tin>; | |
if constexpr(is_fp8_v<Tin>) { | |
if constexpr(is_fp8_v<Tin>) { | |
using outType = half_t; |
* Add comment on final output type conversion
be16975
to
ab0e187
Compare
This PR separates the output type and accumulator type for Flash Attention Prefill. Combinations supported are: * bf16 inputs, fp32 accumulator, bf16 | fp32 output * fp16 inputs, fp32 accumulator, fp16 | fp32 output * fp8 inputs, fp32 accumulator, fp8 | fp32 output Tests added in: #446 Benchmarks added in: #447 --------- Co-authored-by: Alejandro Acosta <[email protected]>
…tlass-fork into flash_prefill_separate_out_tests
725aab4
into
codeplaysoftware:sycl-develop
This PR adds tests for all the different data types supported with Flash Attention Prefill. It is a continuation of PR #443