First version of FP8 scaled_mm. #428
Conversation
@cfgfung, I'll defer to the cutlass team for review, and will add some comments later.
But maybe you can change the state of this PR to draft and rebase it later.
The corresponding integration code with IPEX is failing accuracy, BTW.
You may wish to follow up with the IPEX team and check the root cause of the failing accuracy. This code passes the verify() unit test. I have granted the rights to you, so you should be able to commit/change the code.
 *
 **************************************************************************************************/
/*
 * This implements the scaled_mm for W8A8 GEMM (FP8 weights and activations) using FP16 compute as a workaround,
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is scaled_mm? Can you express that in CUTLASS terms?
Thanks for the review.
This means matrix multiplication with scaling factors.
(scaleA.*TensorA) @ (scaleB.*TensorB)
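For concreteness, a minimal reference sketch of that definition (illustrative names, not code from this PR; element and scale types are assumed to be convertible to float):

```cpp
#include <cstddef>

// Reference semantics of scaled_mm: D = (scale_a .* A) @ (scale_b .* B).
// Row-major layouts; scale_a/scale_b have the same shape as A/B.
template <typename T, typename ScaleT>
void scaled_mm_reference(const T* A, const ScaleT* scale_a,
                         const T* B, const ScaleT* scale_b,
                         float* D, std::size_t M, std::size_t N, std::size_t K) {
  for (std::size_t m = 0; m < M; ++m) {
    for (std::size_t n = 0; n < N; ++n) {
      float acc = 0.f;
      for (std::size_t k = 0; k < K; ++k) {
        float a = float(A[m * K + k]) * float(scale_a[m * K + k]);
        float b = float(B[k * N + n]) * float(scale_b[k * N + n]);
        acc += a * b;
      }
      D[m * N + n] = acc;
    }
  }
}
```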
@cfgfung, please add info on the granularity of scales in the code as well as the PR description, since this scaling is specific to DeepSeek.
For example, for 128x128 A & B blocks, the A scale is applied at the granularity of a per-token, per-128-channel sub-vector of an A block, and the B scale is applied to the whole B block.
For this PR, however, this granularity doesn't matter to the collective & the kernel, because the CUTLASS library user has to ensure that the scales of A & B are of the same size as A & B, so the scaling in this PR is elementwise.
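For concreteness, a sketch of how that DeepSeek-style granularity maps scales onto elements (hypothetical helpers and layouts, not code from this PR; assumes 128x128 quantization blocks and row-major scale tensors):

```cpp
// Illustrative only: scale_a is M x ceil(K/128), scale_b is ceil(K/128) x ceil(N/128).
constexpr int kBlk = 128;

// Scale for A[m][k]: per-token (row m), per-128-channel group along k.
inline float a_scale_at(const float* scale_a, int m, int k, int K) {
  int k_blocks = (K + kBlk - 1) / kBlk;
  return scale_a[m * k_blocks + k / kBlk];
}

// Scale for B[k][n]: one scale per whole 128x128 block of B.
inline float b_scale_at(const float* scale_b, int k, int n, int N) {
  int n_blocks = (N + kBlk - 1) / kBlk;
  return scale_b[(k / kBlk) * n_blocks + n / kBlk];
}
```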
I'll make it more generic in my subsequent PR, so that various workgroup tile-sizes could work with various quantization block sizes.
However, can you please rename the collective & kernel files to indicate that the quantization is supposed to be DeepSeek-style block-wise? In that case, I could modify the same files. Thanks
Shall we call this scaled F8 GEMM? I do not think weights and activations are terms that are relevant at the CUTLASS level. I am also not against adding more details about how the scaling is done.
Hi @t4c1, this implementation requires users to pass scales at the granularity of each element of the A or B matrices (so the scale matrices also have to be of the same size as A or B). That makes this implementation quantization-scheme agnostic. It can't be used in real workloads, though.
template <typename SrcT, typename ScaleT>
void elementwise_multiply_scale(SrcT* d_src, size_t size, ScaleT* d_scale){
  SrcT* h_src_multiplied = new SrcT[size];
use std::vector instead of directly calling new
sycl::memcpy() cannot take a vector as the argument.
I added delete[] instead to handle the memory issue.
sycl::memcpy() cannot take the vector as the argument
The pointer backing a vector (vector.data()) could have been used, though.
Regardless, it doesn't matter, as we shouldn't do this computation on CPU, so please change it to something like https://github.com/mehdi-goli/cutlass-fork/blob/884fa6a8b94adddba9f32b41b6a9d011e1642217/examples/sycl/06_bmg_flash_attention/bmg_flash_attn_decode_runner.hpp#L202-L207
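A possible device-side version along those lines (a sketch only, not the code from the linked example; assumes USM device pointers, a sycl::queue q, and element types convertible to/from float):

```cpp
#include <sycl/sycl.hpp>
#include <cstddef>

// Sketch: do the scaling on the device instead of round-tripping through
// host memory with new[]/memcpy. d_src and d_scale are device (USM) pointers.
template <typename SrcT, typename ScaleT>
void elementwise_multiply_scale(sycl::queue& q, SrcT* d_src, std::size_t size,
                                const ScaleT* d_scale) {
  q.parallel_for(sycl::range<1>(size), [=](sycl::id<1> i) {
     d_src[i] = static_cast<SrcT>(static_cast<float>(d_src[i]) *
                                  static_cast<float>(d_scale[i]));
   }).wait();
}
```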
This comment was not addressed. Is anything unclear here?
This will enable the W8A8 block-quantized matmul for the DeepSeek-R1 model.
ef7ff20 to 2573e3b (Compare)
@t4c1 @sanchitintel
Please add some tests.
Also, can you clarify in what way this differs from the changes in #450?
using GmemTiledCopyScaleA = XE_2D_U16x32x32_LD_N; // Have to use the same shape size as FP8 used in the kernel
using GmemTiledCopyScaleB = XE_2D_U16x32x32_LD_N; // Have to use the same shape size as FP8 used in the kernel
using GmemTiledCopyScaleA = XE_2D_U16x32x32_LD_N; // Shape of the copy atom for scales A must match shape of the copy atom for A in the number of elements
using GmemTiledCopyScaleB = XE_2D_U16x32x32_LD_N; // Shape of the copy atom for scales A must match shape of the copy atom for A in the number of elements
Suggested change:
- using GmemTiledCopyScaleB = XE_2D_U16x32x32_LD_N; // Shape of the copy atom for scales A must match shape of the copy atom for A in the number of elements
+ using GmemTiledCopyScaleB = XE_2D_U16x32x32_LD_N; // Shape of the copy atom for scales B must match shape of the copy atom for B in the number of elements
Closing this, as it is duplicating efforts with others.
This is the initial implementation of scaled_mm(), enabling element-wise scaling for the A & B matrices (the scales of A & B have to be of the same size as A or B, so this implementation is agnostic of the quantization scheme). Based on the existing (unscaled) W8A8 GEMM, this version focuses on delivering core functionality.
@sanchitintel will use this PR as a foundation for further optimizations in subsequent PRs:
Currently, the scale matrices have to be the same size as A and B. They are created on the FW side & passed to cutlass, so there's a creation overhead, the memory-bandwidth requirements increase, and register spilling also happens. Ideally, the framework should pass the scales as-is.
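For illustration, the framework-side expansion that causes this overhead could look roughly like the following (hypothetical helper, not part of this PR): per-block scales are broadcast into a full-size matrix before being handed to CUTLASS.

```cpp
#include <vector>
#include <cstddef>

// Hypothetical framework-side helper: expand per-block scales (one value per
// 128x128 block of B) into a full K x N matrix so the kernel in this PR can
// consume them elementwise. This allocation/broadcast is the overhead noted above.
std::vector<float> expand_b_scales(const std::vector<float>& block_scales,
                                   std::size_t K, std::size_t N, std::size_t blk = 128) {
  std::size_t n_blocks = (N + blk - 1) / blk;
  std::vector<float> full(K * N);
  for (std::size_t k = 0; k < K; ++k)
    for (std::size_t n = 0; n < N; ++n)
      full[k * N + n] = block_scales[(k / blk) * n_blocks + n / blk];
  return full;
}
```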